Skip to content

Commit 5de2524

Browse files
authored
scatter CI failures (#2925)
1 parent 65c4951 commit 5de2524

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,10 @@ def slice_scatter_decomposition(
207207
if i != dim:
208208
index_tensor_shape.append(src_each_dim)
209209
for index in range(start, end, step):
210-
cat_tensors.append(index * torch.ones(index_tensor_shape, dtype=torch.long))
210+
cat_tensors.append(index * torch.ones(index_tensor_shape, dtype=torch.int64))
211211
index_tensor = torch.stack(cat_tensors, dim).cuda()
212-
output_tensor = torch.scatter(input_tensor, dim, index_tensor, src_tensor)
212+
index_tensor_64 = index_tensor.to(torch.int64)
213+
output_tensor = torch.scatter(input_tensor, dim, index_tensor_64, src_tensor)
213214
return output_tensor
214215

215216

0 commit comments

Comments
 (0)