Skip to content

Commit 0701aab

Browse files
authored
[Relax][PyTorch]: Fix sqrt operation requiring a float dtype but receiving int64 in attention scaling (#18454)
This PR fixes issue #18443. --------- Co-authored-by: cchung100m <[email protected]>
1 parent fd57110 commit 0701aab

File tree

4 files changed

+106
-4
lines changed

4 files changed

+106
-4
lines changed

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,26 @@ def _reciprocal(self, node: fx.Node) -> relax.Var:
6464
x = self.env[node.args[0]]
6565
return self.block_builder.emit(relax.op.divide(relax.const(1.0, x.struct_info.dtype), x))
6666

67+
def _sqrt(self, node: fx.Node) -> relax.Var:
    """Translate a torch ``sqrt`` node.

    Relax ``sqrt`` requires a floating-point operand, so integer inputs
    are first promoted to float32 via ``astype`` before the op is emitted.
    """
    operand = self.env[node.args[0]]
    integer_dtypes = {"int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"}
    # Promote integer tensors to float32 so the sqrt op type-checks.
    if operand.struct_info.dtype in integer_dtypes:
        operand = self.block_builder.emit(relax.op.astype(operand, "float32"))
    return self.block_builder.emit(relax.op.sqrt(operand))
76+
77+
def _rsqrt(self, node: fx.Node) -> relax.Var:
    """Translate a torch ``rsqrt`` node.

    Relax ``rsqrt`` requires a floating-point operand, so integer inputs
    are first promoted to float32 via ``astype`` before the op is emitted.
    """
    operand = self.env[node.args[0]]
    integer_dtypes = {"int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"}
    # Promote integer tensors to float32 so the rsqrt op type-checks.
    if operand.struct_info.dtype in integer_dtypes:
        operand = self.block_builder.emit(relax.op.astype(operand, "float32"))
    return self.block_builder.emit(relax.op.rsqrt(operand))
86+
6787
########## Neural Network ##########
6888

6989
def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:
@@ -919,7 +939,7 @@ def create_convert_map(
919939
"relu6.default": self._unary_op(relax.op.nn.relu6),
920940
"relu6_.default": self._unary_op(relax.op.nn.relu6),
921941
"round.default": self._round,
922-
"rsqrt.default": self._unary_op(relax.op.rsqrt),
942+
"rsqrt.default": self._rsqrt,
923943
"scalar_tensor.default": self._scalar_tensor,
924944
"rsub.Tensor": self._rsub,
925945
"rsub.Scalar": self._rsub,
@@ -935,7 +955,7 @@ def create_convert_map(
935955
"softplus.default": self._softplus,
936956
"softshrink.default": self._softshrink,
937957
"softsign.default": self._softsign,
938-
"sqrt.default": self._unary_op(relax.op.sqrt),
958+
"sqrt.default": self._sqrt,
939959
"square.default": self._unary_op(relax.op.square),
940960
"tan.default": self._unary_op(relax.op.tan),
941961
"tanh.default": self._unary_op(relax.op.tanh),

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,26 @@ def _log1p(self, node: fx.Node) -> relax.Var:
9696
one = relax.const(1, x.struct_info.dtype)
9797
return self.block_builder.emit(relax.op.log(relax.op.add(x, one)))
9898

99+
def _sqrt(self, node: fx.Node) -> relax.Var:
    """Translate a torch ``sqrt`` FX node.

    Relax ``sqrt`` requires a floating-point operand, so integer inputs
    are first promoted to float32 via ``astype`` before the op is emitted.
    """
    operand = self.env[node.args[0]]
    integer_dtypes = {"int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"}
    # Promote integer tensors to float32 so the sqrt op type-checks.
    if operand.struct_info.dtype in integer_dtypes:
        operand = self.block_builder.emit(relax.op.astype(operand, "float32"))
    return self.block_builder.emit(relax.op.sqrt(operand))
108+
109+
def _rsqrt(self, node: fx.Node) -> relax.Var:
    """Translate a torch ``rsqrt`` FX node.

    Relax ``rsqrt`` requires a floating-point operand, so integer inputs
    are first promoted to float32 via ``astype`` before the op is emitted.
    """
    operand = self.env[node.args[0]]
    integer_dtypes = {"int8", "int16", "int32", "int64", "uint8", "uint16", "uint32", "uint64"}
    # Promote integer tensors to float32 so the rsqrt op type-checks.
    if operand.struct_info.dtype in integer_dtypes:
        operand = self.block_builder.emit(relax.op.astype(operand, "float32"))
    return self.block_builder.emit(relax.op.rsqrt(operand))
118+
99119
def _log_softmax_module(self, node: fx.Node) -> relax.Var:
100120
x = self.env[node.args[0]]
101121
module = self.named_modules[node.target]
@@ -825,7 +845,7 @@ def create_convert_map(
825845
"relu": self._unary_op(relax.op.nn.relu),
826846
"relu6": self._unary_op(relax.op.nn.relu6),
827847
"round": self._round,
828-
"rsqrt": self._unary_op(relax.op.rsqrt),
848+
"rsqrt": self._rsqrt,
829849
"selu": self._unary_op(relax.op.nn.selu),
830850
"sigmoid": self._unary_op(relax.op.sigmoid),
831851
"sign": self._unary_op(relax.op.sign),
@@ -834,7 +854,7 @@ def create_convert_map(
834854
"sinh": self._unary_op(relax.op.sinh),
835855
"softmax": self._softmax,
836856
"softplus": self._softplus,
837-
"sqrt": self._unary_op(relax.op.sqrt),
857+
"sqrt": self._sqrt,
838858
"square": self._unary_op(relax.op.square),
839859
"tan": self._unary_op(relax.op.tan),
840860
"tanh": self._unary_op(relax.op.tanh),

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,47 @@ def main(
126126
verify_model(UnaryOp(), example_args, {}, expected, run_ep_decomposition=True)
127127

128128

129+
def test_sqrt_integer_input():
    """Test that sqrt operation works with integer tensors by auto-converting to float.

    Covers int64 and int32 inputs; in both cases the translator is expected to
    insert an ``R.astype(..., "float32")`` before ``R.sqrt`` (regression test
    for issue #18443).
    """
    # int64 input case.
    example_args = (torch.tensor([[4, 9, 16, 25]], dtype=torch.int64),)

    class SqrtIntModel(Module):
        def forward(self, input):
            return torch.sqrt(input)

    # Expected Relax module: astype-to-float32 followed by sqrt.
    @tvm.script.ir_module
    class expected_int64:
        @R.function
        def main(
            input_1: R.Tensor((1, 4), dtype="int64")
        ) -> R.Tuple(R.Tensor((1, 4), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 4), dtype="float32") = R.astype(input_1, dtype="float32")
                lv1: R.Tensor((1, 4), dtype="float32") = R.sqrt(lv)
                gv: R.Tuple(R.Tensor((1, 4), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    verify_model(SqrtIntModel(), example_args, {}, expected_int64, run_ep_decomposition=True)

    # int32 input case: same model, different input dtype/shape.
    example_args_int32 = (torch.tensor([[1, 4, 9]], dtype=torch.int32),)

    @tvm.script.ir_module
    class expected_int32:
        @R.function
        def main(
            input_1: R.Tensor((1, 3), dtype="int32")
        ) -> R.Tuple(R.Tensor((1, 3), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3), dtype="float32") = R.astype(input_1, dtype="float32")
                lv1: R.Tensor((1, 3), dtype="float32") = R.sqrt(lv)
                gv: R.Tuple(R.Tensor((1, 3), dtype="float32")) = (lv1,)
                R.output(gv)
            return gv

    verify_model(SqrtIntModel(), example_args_int32, {}, expected_int32, run_ep_decomposition=True)
168+
169+
129170
def test_extended_unary_ops():
130171
example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
131172

tests/python/relax/test_frontend_from_fx.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2749,6 +2749,27 @@ def main(
27492749
verify_model(Unary(), input_info, {}, expected_unary)
27502750

27512751

2752+
def test_sqrt_integer_input_fx():
    """FX-tracing counterpart of the integer-sqrt test.

    Verifies the FX translator inserts ``R.astype(..., "float32")`` before
    ``R.sqrt`` when the input tensor is int64 (regression test for #18443).
    """
    input_info = [([1, 4], "int64")]

    class SqrtIntModel(Module):
        def forward(self, input):
            return torch.sqrt(input)

    # Expected Relax module: astype-to-float32 followed by sqrt.
    @tvm.script.ir_module
    class expected:
        @R.function
        def main(input_1: R.Tensor((1, 4), dtype="int64")) -> R.Tensor((1, 4), dtype="float32"):
            with R.dataflow():
                lv: R.Tensor((1, 4), dtype="float32") = R.astype(input_1, dtype="float32")
                lv1: R.Tensor((1, 4), dtype="float32") = R.sqrt(lv)
                gv: R.Tensor((1, 4), dtype="float32") = lv1
                R.output(gv)
            return gv

    verify_model(SqrtIntModel(), input_info, {}, expected)
2771+
2772+
27522773
operator_bool_unary = [
27532774
(torch.isnan, R.isnan),
27542775
(torch.isinf, R.isinf),

0 commit comments

Comments
 (0)