@@ -128,18 +128,24 @@ def new_init(self, *args, **kwargs):
128128 device = framework ._get_paddle_place (device )
129129 if len (args ) == 0 and len (kwargs ) == 0 : # case 1, 2
130130 original_init (
131- self , paddle .empty (shape = [0 ], dtype = 'float32' ), place = device
131+ self ,
132+ paddle .empty (shape = [0 ], dtype = 'float32' , device = device ),
133+ place = device ,
132134 )
133135 return
134136 if 'data' in kwargs : # case 7,8
135137 data = kwargs .pop ('data' )
136138 original_init (
137- self , paddle .tensor (data , dtype = 'float32' ), place = device
139+ self ,
140+ paddle .tensor (data , dtype = 'float32' , device = device ),
141+ place = device ,
138142 )
139143 elif len (args ) == 1 and isinstance (args [0 ], (list , tuple )):
140144 # case 5, 6
141145 original_init (
142- self , paddle .tensor (args [0 ], dtype = 'float32' ), place = device
146+ self ,
147+ paddle .tensor (args [0 ], dtype = 'float32' , device = device ),
148+ place = device ,
143149 )
144150 elif (
145151 builtins .all (isinstance (arg , int ) for arg in args )
@@ -148,7 +154,7 @@ def new_init(self, *args, **kwargs):
148154 # case 3, 4
149155 original_init (
150156 self ,
151- paddle .empty (shape = list (args ), dtype = 'float32' ),
157+ paddle .empty (shape = list (args ), dtype = 'float32' , device = device ),
152158 place = device ,
153159 )
154160 else :
0 commit comments