Skip to content

Commit 0225d67

Browse files
authored
[Relax][PyTorch] Fix MultiheadAttention compile (#18459)
## Related Issues closes #18440 ## Why - PyTorch `masked_fill` / `full_like` accept inf or nan, and TVM couldn’t handle these values when the tensor dtype was not float, which caused wrong behavior or errors. ## How - If `fill_value` is inf or nan and the tensor dtype is not float → convert the fill to float32. - For masked_fill → create a float values tensor with full_like. - Cast the input to float if needed. - In TOPI → reject creating full with inf/nan on non-float dtypes.
1 parent a9955e5 commit 0225d67

File tree

2 files changed

+47
-5
lines changed

2 files changed

+47
-5
lines changed

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2085,8 +2085,16 @@ def _full(self, node: fx.Node) -> relax.Var:
20852085

20862086
def _full_like(self, node: fx.Node) -> relax.Var:
    """Convert a torch ``full_like`` FX node to a Relax ``full_like`` op.

    PyTorch accepts ``inf``/``nan`` fill values for tensors of any dtype,
    but TVM cannot represent non-finite values in non-float tensors. When
    the fill value is non-finite and the input dtype is not floating
    point, the output dtype is promoted to ``float32``.
    """
    x = self.env[node.args[0]]
    value = node.args[1]
    fill_value = relax.const(value)

    # Promote only for non-finite fills on non-float inputs. Checking
    # "float" as a substring also matches "bfloat16" (and float16/32/64),
    # so no separate bfloat16 test is needed.
    fill_dtype = None
    if isinstance(value, (int, float)) and not math.isfinite(value):
        if "float" not in x.struct_info.dtype:
            fill_dtype = "float32"

    return self.block_builder.emit(relax.op.full_like(x, fill_value, dtype=fill_dtype))
20902098

20912099
def _index_select(self, node: fx.Node) -> relax.Var:
20922100
x = self.env[node.args[0]]
@@ -2099,7 +2107,19 @@ def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
20992107
mask = self.env[node.args[1]]
21002108
value = node.args[2]
21012109
rx_value = relax.const(value)
2102-
values = self.block_builder.emit(relax.op.full_like(x, rx_value))
2110+
2111+
x_dtype = x.struct_info.dtype
2112+
fill_dtype = None
2113+
if isinstance(value, (int, float)) and (math.isinf(value) or math.isnan(value)):
2114+
if not ("float" in x_dtype or "bfloat16" in x_dtype):
2115+
fill_dtype = "float32"
2116+
2117+
values = self.block_builder.emit(relax.op.full_like(x, rx_value, dtype=fill_dtype))
2118+
2119+
# Cast x to match values dtype if necessary
2120+
if fill_dtype is not None:
2121+
x = self.block_builder.emit(relax.op.astype(x, fill_dtype))
2122+
21032123
output = self.block_builder.emit(relax.op.where(mask, values, x))
21042124
self.env[node.args[0]] = output
21052125
return output
@@ -2130,8 +2150,21 @@ def _linspace(self, node: fx.Node) -> relax.Var:
21302150
def _masked_fill(self, node: fx.Node) -> relax.Var:
    """Convert a torch ``masked_fill`` FX node to Relax.

    Builds a tensor filled with the given value via ``full_like`` and
    selects between it and the input with ``where``. Non-finite fill
    values (inf/nan) cannot be stored in non-float tensors, so in that
    case both the fill tensor and the input are promoted to ``float32``.
    """
    x = self.env[node.args[0]]
    mask = self.env[node.args[1]]
    value = node.args[2]
    rx_value = relax.const(value)

    # Promote only for non-finite fills on non-float inputs. Checking
    # "float" as a substring also matches "bfloat16" (and float16/32/64),
    # so no separate bfloat16 test is needed.
    fill_dtype = None
    if isinstance(value, (int, float)) and not math.isfinite(value):
        if "float" not in x.struct_info.dtype:
            fill_dtype = "float32"

    values = self.block_builder.emit(relax.op.full_like(x, rx_value, dtype=fill_dtype))

    # Cast x so that where() sees matching operand dtypes.
    if fill_dtype is not None:
        x = self.block_builder.emit(relax.op.astype(x, fill_dtype))

    return self.block_builder.emit(relax.op.where(mask, values, x))
21362169

21372170
def _new_ones(self, node: fx.Node) -> relax.Var:

python/tvm/topi/tensor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
# pylint: disable=invalid-name,consider-using-enumerate,unused-argument,len-as-condition
1818
"""Elementwise operators"""
1919

20+
import math as _math
21+
2022
from typing import Optional
2123

2224
from tvm import te
@@ -57,6 +59,13 @@ def full(shape, dtype, fill_value):
5759
y : tvm.te.Tensor
5860
The result.
5961
"""
62+
63+
if isinstance(fill_value, (int, float)) and (
64+
_math.isinf(fill_value) or _math.isnan(fill_value)
65+
):
66+
if not ("float" in dtype or "bfloat16" in dtype):
67+
raise ValueError("Infinite and NaN require a floating-point dtype.")
68+
6069
return cpp.full(shape, dtype, fill_value)
6170

6271

0 commit comments

Comments
 (0)