
Commit 01e88aa

Merge branch 'main' into justinchu/should-fold

2 parents: 589ff9d + 88b03d8

File tree

6 files changed: +94, -50 lines

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 16 additions & 12 deletions
@@ -3688,23 +3688,27 @@ def python_math_floor(self: TFloat) -> TInt:
 
 
 @torch_op("aten::floor_divide", trace_only=True)
-def aten_floor_divide(self: TFloat, other: TFloat) -> TFloat:
+def aten_floor_divide(self: TTensor, other: TTensor) -> TTensor:
     """floor_divide(Tensor self, Tensor other) -> Tensor"""
 
-    return op.Floor(op.Div(self, other))
+    if self.dtype.is_floating_point():
+        return op.Floor(op.Div(self, other))
 
+    assert self.dtype.is_integer()
 
-@torch_op("aten::floor_divide", trace_only=True)
-def aten_floor_divide_int(self: TInt, other: TInt) -> TInt:
-    """floor_divide(Tensor self, Tensor other) -> Tensor"""
+    if not self.dtype.is_signed():
+        return op.Div(self, other)
 
-    # TODO(justinchuby): This can be simplified if we can constrain the
-    # inputs to be positive integers. Consider how we can embed constraints in the model.
-    dtype = self.dtype
-    self = op.Cast(self, to=FLOAT.dtype)
-    other = op.Cast(other, to=FLOAT.dtype)
-    result = op.Floor(op.Div(self, other))
-    return op.Cast(result, to=dtype)
+    # Convert truncation to flooring
+    # Reference: https://github.com/pytorch/pytorch/blob/ffc645c870f0abd368606ba1e2b3b58cacb03046/torch/_refs/__init__.py#L1401C1-L1409C70
+    # offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0)
+    # return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype)
+    offset = op.And(
+        op.Not(op.Equal(op.Sign(self), op.Sign(other))),
+        op.Cast(op.Mod(self, other), to=BOOL.dtype),
+    )
+    offset = op.Cast(offset, to=self.dtype)
+    return op.Sub(op.Div(self, other), offset)
 
 
 @torch_op("_operator::floordiv", trace_only=True)

onnxscript/rewriter/ort_fusions/sdpa.py

Lines changed: 3 additions & 0 deletions
@@ -88,6 +88,9 @@ def pattern(
         )
 
         attn_weight = op.Softmax(attn_score, axis=-1)
+        is_nan = op.IsNaN(attn_weight)
+        adj_attn_weight = op.Where(is_nan, 0.0, attn_weight)
+        attn_weight = pattern.OrValue([adj_attn_weight, attn_weight])
         attn_output = op.MatMul(attn_weight, value)
         return attn_output
 
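The three added lines let the SDPA pattern also match graphs that zero out NaN attention weights: when a mask sets an entire row of scores to -inf, Softmax computes 0/0 and yields NaN for that row, and Where(IsNaN, 0, w) replaces it with zeros so the row contributes nothing to the final MatMul; pattern.OrValue then accepts either the adjusted or the original weights. A small NumPy illustration of that NaN behavior (illustrative only, not onnxscript code):

import numpy as np

# Row 0 is a normal attention row; row 1 is fully masked with -inf.
scores = np.array([[1.0, 2.0, 3.0],
                   [-np.inf, -np.inf, -np.inf]])

with np.errstate(invalid="ignore"):
    weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)

print(np.isnan(weights[1]).all())   # True: softmax of an all -inf row is NaN
adjusted = np.where(np.isnan(weights), 0.0, weights)
print(adjusted[1])                  # [0. 0. 0.] -> contributes nothing downstream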

onnxscript/rewriter/ort_fusions/sdpa_test.py

Lines changed: 68 additions & 17 deletions
@@ -44,7 +44,10 @@ def _unmasked_pre_div_sdpa_script(query, key, value):
     scaled_key = op.Div(key_transposed, divisor)
     attn_score = op.MatMul(scaled_query, scaled_key)
     attn_weight = op.Softmax(attn_score, axis=-1)
-    attn_output = op.MatMul(attn_weight, value)
+    is_nan = op.IsNaN(attn_weight)
+    zero = op.Constant(value_float=0.0)
+    adj_attn_weight = op.Where(is_nan, zero, attn_weight)
+    attn_output = op.MatMul(adj_attn_weight, value)
     return attn_output
 
 
@@ -56,7 +59,10 @@ def _unmasked_pre_mul_sdpa_script(query, key, value):
     scaled_key = op.Mul(key_transposed, multiplier)
     attn_score = op.MatMul(scaled_query, scaled_key)
     attn_weight = op.Softmax(attn_score, axis=-1)
-    attn_output = op.MatMul(attn_weight, value)
+    is_nan = op.IsNaN(attn_weight)
+    zero = op.Constant(value_float=0.0)
+    adj_attn_weight = op.Where(is_nan, zero, attn_weight)
+    attn_output = op.MatMul(adj_attn_weight, value)
     return attn_output
 
 
@@ -67,7 +73,10 @@ def _unmasked_post_div_sdpa_script(query, key, value):
     attn_score = op.MatMul(query, key_transposed)
     scaled_attn_score = op.Div(attn_score, divisor)
     attn_weight = op.Softmax(scaled_attn_score, axis=-1)
-    attn_output = op.MatMul(attn_weight, value)
+    is_nan = op.IsNaN(attn_weight)
+    zero = op.Constant(value_float=0.0)
+    adj_attn_weight = op.Where(is_nan, zero, attn_weight)
+    attn_output = op.MatMul(adj_attn_weight, value)
     return attn_output
 
 
@@ -78,7 +87,10 @@ def _unmasked_post_mul_sdpa_script(query, key, value):
     attn_score = op.MatMul(query, key_transposed)
     scaled_attn_score = op.Mul(attn_score, multiplier)
     attn_weight = op.Softmax(scaled_attn_score, axis=-1)
-    attn_output = op.MatMul(attn_weight, value)
+    is_nan = op.IsNaN(attn_weight)
+    zero = op.Constant(value_float=0.0)
+    adj_attn_weight = op.Where(is_nan, zero, attn_weight)
+    attn_output = op.MatMul(adj_attn_weight, value)
     return attn_output
 
 
@@ -90,7 +102,10 @@ def _custom_scale_pre_div_sdpa_script(query, key, value):
     scaled_key = op.Div(key_transposed, divisor)
     attn_score = op.MatMul(scaled_query, scaled_key)
     attn_weight = op.Softmax(attn_score, axis=-1)
-    attn_output = op.MatMul(attn_weight, value)
+    is_nan = op.IsNaN(attn_weight)
+    zero = op.Constant(value_float=0.0)
+    adj_attn_weight = op.Where(is_nan, zero, attn_weight)
+    attn_output = op.MatMul(adj_attn_weight, value)
     return attn_output
 
 
@@ -102,7 +117,10 @@ def _custom_scale_pre_mul_sdpa_script(query, key, value):
     scaled_key = op.Mul(key_transposed, multiplier)
     attn_score = op.MatMul(scaled_query, scaled_key)
     attn_weight = op.Softmax(attn_score, axis=-1)
-    attn_output = op.MatMul(attn_weight, value)
+    is_nan = op.IsNaN(attn_weight)
+    zero = op.Constant(value_float=0.0)
+    adj_attn_weight = op.Where(is_nan, zero, attn_weight)
+    attn_output = op.MatMul(adj_attn_weight, value)
     return attn_output
 
 
@@ -115,7 +133,10 @@ def _custom_multi_scale_pre_mul_sdpa_script(query, key, value):
     scaled_key = op.Mul(key_transposed, multiplier_k)
     attn_score = op.MatMul(scaled_query, scaled_key)
     attn_weight = op.Softmax(attn_score, axis=-1)
-    attn_output = op.MatMul(attn_weight, value)
+    is_nan = op.IsNaN(attn_weight)
+    zero = op.Constant(value_float=0.0)
+    adj_attn_weight = op.Where(is_nan, zero, attn_weight)
+    attn_output = op.MatMul(adj_attn_weight, value)
     return attn_output
 
 
@@ -126,7 +147,10 @@ def _custom_scale_post_div_sdpa_script(query, key, value):
     attn_score = op.MatMul(query, key_transposed)
     scaled_attn_score = op.Div(attn_score, divisor)
     attn_weight = op.Softmax(scaled_attn_score, axis=-1)
-    attn_output = op.MatMul(attn_weight, value)
+    is_nan = op.IsNaN(attn_weight)
+    zero = op.Constant(value_float=0.0)
+    adj_attn_weight = op.Where(is_nan, zero, attn_weight)
+    attn_output = op.MatMul(adj_attn_weight, value)
     return attn_output
 
 
@@ -137,7 +161,10 @@ def _custom_scale_post_mul_sdpa_script(query, key, value):
     attn_score = op.MatMul(query, key_transposed)
     scaled_attn_score = op.Mul(attn_score, multiplier)
     attn_weight = op.Softmax(scaled_attn_score, axis=-1)
-    attn_output = op.MatMul(attn_weight, value)
+    is_nan = op.IsNaN(attn_weight)
+    zero = op.Constant(value_float=0.0)
+    adj_attn_weight = op.Where(is_nan, zero, attn_weight)
+    attn_output = op.MatMul(adj_attn_weight, value)
     return attn_output
 
 
@@ -150,7 +177,10 @@ def _masked_pre_div_sdpa_script(query, key, value, mask):
     attn_score = op.MatMul(scaled_query, scaled_key)
     masked_attn_score = op.Add(attn_score, mask)
     attn_weight = op.Softmax(masked_attn_score, axis=-1)
-    attn_output = op.MatMul(attn_weight, value)
+    is_nan = op.IsNaN(attn_weight)
+    zero = op.Constant(value_float=0.0)
+    adj_attn_weight = op.Where(is_nan, zero, attn_weight)
+    attn_output = op.MatMul(adj_attn_weight, value)
     return attn_output
 
 
@@ -163,7 +193,10 @@ def _masked_pre_mul_sdpa_script(query, key, value, mask):
     attn_score = op.MatMul(scaled_query, scaled_key)
     masked_attn_score = op.Add(attn_score, mask)
    attn_weight = op.Softmax(masked_attn_score, axis=-1)
-    attn_output = op.MatMul(attn_weight, value)
+    is_nan = op.IsNaN(attn_weight)
+    zero = op.Constant(value_float=0.0)
+    adj_attn_weight = op.Where(is_nan, zero, attn_weight)
+    attn_output = op.MatMul(adj_attn_weight, value)
     return attn_output
 
 
@@ -175,7 +208,10 @@ def _masked_post_div_sdpa_script(query, key, value, mask):
     scaled_attn_score = op.Div(attn_score, divisor)
     masked_attn_score = op.Add(scaled_attn_score, mask)
     attn_weight = op.Softmax(masked_attn_score, axis=-1)
-    attn_output = op.MatMul(attn_weight, value)
+    is_nan = op.IsNaN(attn_weight)
+    zero = op.Constant(value_float=0.0)
+    adj_attn_weight = op.Where(is_nan, zero, attn_weight)
+    attn_output = op.MatMul(adj_attn_weight, value)
     return attn_output
 
 
@@ -187,7 +223,10 @@ def _masked_post_mul_sdpa_script(query, key, value, mask):
     scaled_attn_score = op.Mul(attn_score, multiplier)
     masked_attn_score = op.Add(scaled_attn_score, mask)
     attn_weight = op.Softmax(masked_attn_score, axis=-1)
-    attn_output = op.MatMul(attn_weight, value)
+    is_nan = op.IsNaN(attn_weight)
+    zero = op.Constant(value_float=0.0)
+    adj_attn_weight = op.Where(is_nan, zero, attn_weight)
+    attn_output = op.MatMul(adj_attn_weight, value)
     return attn_output
 
 
@@ -200,7 +239,10 @@ def _masked_custom_scale_pre_div_sdpa_script(query, key, value, mask):
     attn_score = op.MatMul(scaled_query, scaled_key)
     masked_attn_score = op.Add(attn_score, mask)
     attn_weight = op.Softmax(masked_attn_score, axis=-1)
-    attn_output = op.MatMul(attn_weight, value)
+    is_nan = op.IsNaN(attn_weight)
+    zero = op.Constant(value_float=0.0)
+    adj_attn_weight = op.Where(is_nan, zero, attn_weight)
+    attn_output = op.MatMul(adj_attn_weight, value)
     return attn_output
 
 
@@ -213,7 +255,10 @@ def _masked_custom_scale_pre_mul_sdpa_script(query, key, value, mask):
     attn_score = op.MatMul(scaled_query, scaled_key)
     masked_attn_score = op.Add(attn_score, mask)
     attn_weight = op.Softmax(masked_attn_score, axis=-1)
-    attn_output = op.MatMul(attn_weight, value)
+    is_nan = op.IsNaN(attn_weight)
+    zero = op.Constant(value_float=0.0)
+    adj_attn_weight = op.Where(is_nan, zero, attn_weight)
+    attn_output = op.MatMul(adj_attn_weight, value)
     return attn_output
 
 
@@ -225,7 +270,10 @@ def _masked_custom_scale_post_div_sdpa_script(query, key, value, mask):
     scaled_attn_score = op.Div(attn_score, divisor)
     masked_attn_score = op.Add(scaled_attn_score, mask)
     attn_weight = op.Softmax(masked_attn_score, axis=-1)
-    attn_output = op.MatMul(attn_weight, value)
+    is_nan = op.IsNaN(attn_weight)
+    zero = op.Constant(value_float=0.0)
+    adj_attn_weight = op.Where(is_nan, zero, attn_weight)
+    attn_output = op.MatMul(adj_attn_weight, value)
     return attn_output
 
 
@@ -237,7 +285,10 @@ def _masked_custom_scale_post_mul_sdpa_script(query, key, value, mask):
     scaled_attn_score = op.Mul(attn_score, multiplier)
     masked_attn_score = op.Add(scaled_attn_score, mask)
     attn_weight = op.Softmax(masked_attn_score, axis=-1)
-    attn_output = op.MatMul(attn_weight, value)
+    is_nan = op.IsNaN(attn_weight)
+    zero = op.Constant(value_float=0.0)
+    adj_attn_weight = op.Where(is_nan, zero, attn_weight)
+    attn_output = op.MatMul(adj_attn_weight, value)
     return attn_output
 
onnxscript/rewriter/rules/common/_fuse_conv_affine_test.py

Lines changed: 6 additions & 10 deletions
@@ -18,9 +18,7 @@ def clone_model(self, model: ir.Model) -> ir.Model:
 
     def test_conv_affine_fusion(self):
         tape = ir.tape.Tape()
-        x = ir.Input(
-            "x", shape=ir.Shape([1, 3, 32, 32]), type=ir.TensorType(ir.DataType.FLOAT)
-        )
+        x = ir.val("x", dtype=ir.DataType.FLOAT, shape=ir.Shape([1, 3, 32, 32]))
         w = tape.initializer(ir.tensor(np.ones((3, 3, 3, 3), dtype=np.float32), name="w"))
         b = tape.initializer(ir.tensor(np.ones((3,), dtype=np.float32), name="b"))
         scale = tape.initializer(ir.tensor(np.array([2.0], dtype=np.float32), name="scale"))
@@ -31,10 +29,10 @@ def test_conv_affine_fusion(self):
         z = tape.op(
             "Add",
             [mul_out, offset],
-            output=ir.Input(
+            output=ir.val(
                 "z",
+                dtype=ir.DataType.FLOAT,
                 shape=ir.Shape([1, 3, 32, 32]),
-                type=ir.TensorType(ir.DataType.FLOAT),
             ),
         )
 
@@ -65,9 +63,7 @@ def test_conv_affine_fusion(self):
 
     def test_affine_conv_fusion_without_pad(self):
         tape = ir.tape.Tape()
-        x = ir.Input(
-            "x", shape=ir.Shape([1, 3, 32, 32]), type=ir.TensorType(ir.DataType.FLOAT)
-        )
+        x = ir.val("x", dtype=ir.DataType.FLOAT, shape=ir.Shape([1, 3, 32, 32]))
         w = tape.initializer(ir.tensor(np.ones((3, 3, 3, 3), dtype=np.float32), name="w"))
         b = tape.initializer(ir.tensor(np.ones((3,), dtype=np.float32), name="b"))
         scale = tape.initializer(ir.tensor(np.array([2.0], dtype=np.float32), name="scale"))
@@ -77,10 +73,10 @@ def test_affine_conv_fusion_without_pad(self):
         z = tape.op(
             "Add",
             [mul_out, offset],
-            output=ir.Input(
+            output=ir.val(
                 "z",
+                dtype=ir.DataType.FLOAT,
                 shape=ir.Shape([1, 3, 32, 32]),
-                type=ir.TensorType(ir.DataType.FLOAT),
             ),
         )
         conv_out = tape.op("Conv", [z, w, b], attributes={"pads": [0, 0, 0, 0]})
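These tests now build typed values with ir.val instead of ir.Input plus an explicit ir.TensorType. A minimal sketch of the new call, using only the arguments that appear in this diff (the property accesses in the print are assumptions, not something this commit exercises):

from onnxscript import ir

# ir.val bundles name, element dtype, and shape into one Value construction.
x = ir.val("x", dtype=ir.DataType.FLOAT, shape=ir.Shape([1, 3, 32, 32]))
print(x.name, x.dtype, x.shape)  # assumed Value properties: name, dtype, shape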

tests/function_libs/torch_lib/extra_opinfo.py

Lines changed: 1 addition & 10 deletions
@@ -2270,18 +2270,9 @@ def __init__(self):
     opinfo_core.BinaryUfuncInfo(
         "ops.aten.floor_divide",
         aten_name="floor_divide",
-        dtypes=common_dtype.floating_types_and_half(),
+        dtypes=common_dtype.all_types_and_half(),
         rhs_make_tensor_kwargs=dict(exclude_zero=True),
     ),
-    opinfo_core.BinaryUfuncInfo(
-        "ops.aten.floor_divide.int",
-        aten_name="floor_divide",
-        op=torch.ops.aten.floor_divide,
-        dtypes=common_dtype.integral_types(),
-        # Create only positive inputs
-        lhs_make_tensor_kwargs=dict(low=0),
-        rhs_make_tensor_kwargs=dict(exclude_zero=True, low=0),
-    ),
     opinfo_core.OpInfo(
         "ops.aten.hamming_window",
         aten_name="hamming_window",

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 0 additions & 1 deletion
@@ -794,7 +794,6 @@ def _where_input_wrangler(
     TorchLibOpInfo("flatten", core_ops.aten_flatten),
     TorchLibOpInfo("floor", core_ops.aten_floor),
     TorchLibOpInfo("ops.aten.floor_divide", core_ops.aten_floor_divide),
-    TorchLibOpInfo("ops.aten.floor_divide.int", core_ops.aten_floor_divide_int),
     TorchLibOpInfo("fmod", core_ops.aten_fmod),
     TorchLibOpInfo("frac", core_ops.aten_frac),
     TorchLibOpInfo("full", core_ops.aten_full),
