Commit eff9138

Update unpacked pattern to use aten::masked_fill_ instead of the prim::If-guarded aten::masked_fill, due to a conversion issue when input data types are mismatched
1 parent e94c64a commit eff9138

File tree

1 file changed (+1, -12 lines)

core/lowering/passes/device_casting.cpp

Lines changed: 1 addition & 12 deletions
@@ -23,18 +23,7 @@ void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph) {
       %false: bool = prim::Constant[value=0]()
       %mask_cuda: Tensor = aten::to(%mask, %device, %dtype, %false, %false)
       %self_cuda: Tensor = aten::to(%self, %device, %dtype, %false, %false)
-
-      # Value is cast to type of original tensor and value defaults to float
-      %is_float: bool = aten::is_floating_point(%self)
-      %out: Tensor = prim::If(%is_float)
-        block0():
-          %no_cast: Tensor = aten::masked_fill(%self_cuda, %mask_cuda, %value)
-          -> (%no_cast)
-        block1():
-          %value_int: int = aten::Int(%value)
-          %casted_int: Tensor = aten::masked_fill(%self_cuda, %mask_cuda, %value_int)
-          -> (%casted_int)
-
+      %out: Tensor = aten::masked_fill_(%self_cuda, %mask_cuda, %value)
       return (%out))IR";

   torch::jit::SubgraphRewriter masked_fill_rewriter;
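
For context, the pattern pair touched by this diff is applied through torch::jit::SubgraphRewriter, whose declaration appears in the trailing context line above. The following is a minimal sketch, not the exact Torch-TensorRT source, of how such a lowering pass is typically assembled: the matcher pattern, the hard-coded "cuda" device constant, and the surrounding function body are assumptions inferred from the diff, while the replacement body mirrors the post-patch pattern.

// Minimal sketch of the lowering pass around the patched pattern.
// The matcher pattern and "cuda" device constant are assumptions; only the
// replacement body's masked_fill_ line corresponds to the +1 line in the diff.
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/subgraph_rewrite.h>

void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph) {
  // Pattern to match: a bare in-place masked_fill_ call on possibly-CPU inputs.
  std::string masked_fill_pattern = R"IR(
    graph(%self, %mask, %value):
      %out: Tensor = aten::masked_fill_(%self, %mask, %value)
      return (%out))IR";

  // Replacement: move %self and %mask to the GPU first, then fill in place,
  // forwarding %value unchanged (no prim::If / aten::Int casting).
  std::string unpacked_pattern = R"IR(
    graph(%self, %mask, %value):
      %device: Device = prim::Constant[value="cuda"]()
      %dtype: NoneType = prim::Constant()
      %false: bool = prim::Constant[value=0]()
      %mask_cuda: Tensor = aten::to(%mask, %device, %dtype, %false, %false)
      %self_cuda: Tensor = aten::to(%self, %device, %dtype, %false, %false)
      %out: Tensor = aten::masked_fill_(%self_cuda, %mask_cuda, %value)
      return (%out))IR";

  // Rewrite every occurrence of the matcher pattern with the unpacked pattern.
  torch::jit::SubgraphRewriter masked_fill_rewriter;
  masked_fill_rewriter.RegisterRewritePattern(masked_fill_pattern, unpacked_pattern);
  masked_fill_rewriter.runOnGraph(graph);
}

Per the commit message, forwarding %value unchanged sidesteps the aten::Int conversion performed in the removed prim::If branch, which caused problems when the input data types were mismatched.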
