Skip to content

Commit 3a3d62a

Browse files
authored
Fixing slice scatter and select scatter decomposition (#3093)
1 parent ffa4f64 commit 3a3d62a

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ def slice_scatter_decomposition(
190190
step: Optional[int] = None,
191191
) -> torch.Tensor:
192192
dim_size = input_tensor.shape[dim]
193+
device_input_tensor = input_tensor.device
193194
start = get_positive_dim(start, input_tensor.shape[dim])
194195
if end is None:
195196
end = dim_size
@@ -216,7 +217,8 @@ def slice_scatter_decomposition(
216217
index_tensor_shape.append(src_each_dim)
217218
for index in range(start, end, step):
218219
cat_tensors.append(index * torch.ones(index_tensor_shape, dtype=torch.int64))
219-
index_tensor = torch.stack(cat_tensors, dim).to(input_tensor.device)
220+
index_tensor = torch.stack(cat_tensors, dim)
221+
index_tensor = index_tensor.to(device_input_tensor)
220222
index_tensor_64 = index_tensor.to(torch.int64)
221223
output_tensor = torch.scatter(input_tensor, dim, index_tensor_64, src_tensor)
222224
return output_tensor

0 commit comments

Comments
 (0)