Skip to content

Commit 2ba7e79

Browse files
authored
fix the bug in constrcut Tensor with diference place (#75022)
1 parent f6e29de commit 2ba7e79

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

python/paddle/__init__.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,18 +128,24 @@ def new_init(self, *args, **kwargs):
128128
device = framework._get_paddle_place(device)
129129
if len(args) == 0 and len(kwargs) == 0: # case 1, 2
130130
original_init(
131-
self, paddle.empty(shape=[0], dtype='float32'), place=device
131+
self,
132+
paddle.empty(shape=[0], dtype='float32', device=device),
133+
place=device,
132134
)
133135
return
134136
if 'data' in kwargs: # case 7,8
135137
data = kwargs.pop('data')
136138
original_init(
137-
self, paddle.tensor(data, dtype='float32'), place=device
139+
self,
140+
paddle.tensor(data, dtype='float32', device=device),
141+
place=device,
138142
)
139143
elif len(args) == 1 and isinstance(args[0], (list, tuple)):
140144
# case 5, 6
141145
original_init(
142-
self, paddle.tensor(args[0], dtype='float32'), place=device
146+
self,
147+
paddle.tensor(args[0], dtype='float32', device=device),
148+
place=device,
143149
)
144150
elif (
145151
builtins.all(isinstance(arg, int) for arg in args)
@@ -148,7 +154,7 @@ def new_init(self, *args, **kwargs):
148154
# case 3, 4
149155
original_init(
150156
self,
151-
paddle.empty(shape=list(args), dtype='float32'),
157+
paddle.empty(shape=list(args), dtype='float32', device=device),
152158
place=device,
153159
)
154160
else:

0 commit comments

Comments
 (0)