@@ -128,18 +128,24 @@ def new_init(self, *args, **kwargs):
128
128
device = framework ._get_paddle_place (device )
129
129
if len (args ) == 0 and len (kwargs ) == 0 : # case 1, 2
130
130
original_init (
131
- self , paddle .empty (shape = [0 ], dtype = 'float32' ), place = device
131
+ self ,
132
+ paddle .empty (shape = [0 ], dtype = 'float32' , device = device ),
133
+ place = device ,
132
134
)
133
135
return
134
136
if 'data' in kwargs : # case 7,8
135
137
data = kwargs .pop ('data' )
136
138
original_init (
137
- self , paddle .tensor (data , dtype = 'float32' ), place = device
139
+ self ,
140
+ paddle .tensor (data , dtype = 'float32' , device = device ),
141
+ place = device ,
138
142
)
139
143
elif len (args ) == 1 and isinstance (args [0 ], (list , tuple )):
140
144
# case 5, 6
141
145
original_init (
142
- self , paddle .tensor (args [0 ], dtype = 'float32' ), place = device
146
+ self ,
147
+ paddle .tensor (args [0 ], dtype = 'float32' , device = device ),
148
+ place = device ,
143
149
)
144
150
elif (
145
151
builtins .all (isinstance (arg , int ) for arg in args )
@@ -148,7 +154,7 @@ def new_init(self, *args, **kwargs):
148
154
# case 3, 4
149
155
original_init (
150
156
self ,
151
- paddle .empty (shape = list (args ), dtype = 'float32' ),
157
+ paddle .empty (shape = list (args ), dtype = 'float32' , device = device ),
152
158
place = device ,
153
159
)
154
160
else :
0 commit comments