Skip to content

Commit b7ab952

Browse files
committed
Fix unsqueeze axes dims in test case
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent b5c0128 commit b7ab952

File tree

4 files changed

+19
-19
lines changed

4 files changed

+19
-19
lines changed

onnxscript/rewriter/models/_rotary_embedding_models.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -26,8 +26,8 @@ def _test_case_1_script(x: FLOAT[1, 4, 8, 8], position_ids: INT64[1, 8]) -> FLOA
2626
emb = op.Concat(freqs, freqs, axis=-1)
2727
cos = op.Cos(emb)
2828
sin = op.Sin(emb)
29-
cos_4d = op.Unsqueeze(cos, 1)
30-
sin_4d = op.Unsqueeze(sin, 1)
29+
cos_4d = op.Unsqueeze(cos, [1])
30+
sin_4d = op.Unsqueeze(sin, [1])
3131

3232
x1 = op.Slice(x, [0], [4], [3], [1])
3333
x2 = op.Slice(x, [4], [8], [3], [1])
@@ -73,8 +73,8 @@ def _test_case_2_script(x: FLOAT[1, 4, 8, 8], position_ids: INT64[8]) -> FLOAT[1
7373
emb = op.Concat(freqs, freqs, axis=-1)
7474
cos = op.Cos(emb)
7575
sin = op.Sin(emb)
76-
cos_4d = op.Unsqueeze(cos, 1)
77-
sin_4d = op.Unsqueeze(sin, 1)
76+
cos_4d = op.Unsqueeze(cos, [1])
77+
sin_4d = op.Unsqueeze(sin, [1])
7878

7979
x1 = op.Slice(x, [0], [4], [3], [1])
8080
x2 = op.Slice(x, [4], [8], [3], [1])
@@ -127,8 +127,8 @@ def _partial_rotary_script(position_ids, query):
127127
# Split the query for partial embedding
128128
to_embed = op.Slice(query, [0], [32], [3], [1])
129129
unembedded = op.Slice(query, [32], [9223372036854775807], [3], [1])
130-
cos_4d = op.Unsqueeze(cos_3d, 1) # [B, 1, S, rd]
131-
sin_4d = op.Unsqueeze(sin_3d, 1) # [B, 1, S, rd]
130+
cos_4d = op.Unsqueeze(cos_3d, [1]) # [B, 1, S, rd]
131+
sin_4d = op.Unsqueeze(sin_3d, [1]) # [B, 1, S, rd]
132132
# Compute rotation of X as X * cos + rotate_half(X) * sin, where rotate_half(X)
133133
# essentially represents X rotated by 90 degrees
134134
to_embed_times_cos = op.Mul(to_embed, cos_4d)

onnxscript/rewriter/models/_smollm_1.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -59,8 +59,8 @@ def main_graph(
5959
minus_inf_10x10 = opset18.ConstantOfShape([10, 10], [-3.4028234663852886e38])
6060
mask_10x10 = opset18.Trilu(minus_inf_10x10, 1)
6161
slice_5 = opset18.Reshape(mask_10x10, [1, 1, 10, 10])
62-
unsqueeze_2 = opset18.Unsqueeze(input1, 1)
63-
unsqueeze_3 = opset18.Unsqueeze(unsqueeze_2, 2)
62+
unsqueeze_2 = opset18.Unsqueeze(input1, [1])
63+
unsqueeze_3 = opset18.Unsqueeze(unsqueeze_2, [2])
6464
add = slice_5 + unsqueeze_3
6565
eq = add == 0.0
6666
slice_10 = slice_5
@@ -69,7 +69,7 @@ def main_graph(
6969
slice_scatter = opset18.Transpose(val_179, perm=[2, 1, 0, 3])
7070
val_191 = opset18.Transpose(slice_scatter, perm=[1, 0, 2, 3])
7171
slice_scatter_1 = opset18.Transpose(val_191, perm=[1, 0, 2, 3])
72-
unsqueeze_6 = opset18.Unsqueeze(input2, 1)
72+
unsqueeze_6 = opset18.Unsqueeze(input2, [1])
7373
to_copy_1 = opset18.Cast(unsqueeze_6, to=1)
7474
view_1 = opset18.Constant(
7575
value=ir.tensor(
@@ -138,8 +138,8 @@ def main_graph(
138138
transpose_2 = opset18.Transpose(view_11, perm=[0, 2, 1, 3])
139139
view_12 = opset18.Reshape(view_9, [1, 10, 32, 64], allowzero=0)
140140
transpose_3 = opset18.Transpose(view_12, perm=[0, 2, 1, 3])
141-
unsqueeze_7 = opset18.Unsqueeze(cos, 1)
142-
unsqueeze_8 = opset18.Unsqueeze(sin, 1)
141+
unsqueeze_7 = opset18.Unsqueeze(cos, [1])
142+
unsqueeze_8 = opset18.Unsqueeze(sin, [1])
143143
mul_5 = transpose_1 * unsqueeze_7
144144
val_267 = opset18.Constant(value_ints=[1])
145145
slice_19 = opset18.Slice(transpose_1, [0], [32], [3], val_267)

onnxscript/rewriter/models/_smollm_2.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,7 @@ def main_graph(
5151
gt = arange_1 > view
5252
convert_element_type_default = opset18.Cast(gt, to=1)
5353
mul = triu * convert_element_type_default
54-
dim__2 = opset18.Constant(value_int=0)
54+
dim__2 = opset18.Constant(value_ints=[0])
5555
dim_0__2 = opset18.Cast(dim__2, to=7)
5656
unsqueeze = opset18.Unsqueeze(model_rotary_emb_inv_freq, dim_0__2)
5757
val_15 = opset18.Cast(0, to=7)
@@ -65,7 +65,7 @@ def main_graph(
6565
val_25 = opset18.Reshape(val_23, val_24, allowzero=0)
6666
val_26 = opset18.Constant(value_ints=[1])
6767
slice_1 = opset18.Slice(unsqueeze, val_17, val_21, val_25, val_26)
68-
dim__3 = opset18.Constant(value_int=2)
68+
dim__3 = opset18.Constant(value_ints=[2])
6969
dim_0__3 = opset18.Cast(dim__3, to=7)
7070
unsqueeze_1 = opset18.Unsqueeze(slice_1, dim_0__3)
7171
_to_copy = opset18.Cast(unsqueeze_1, to=1)
@@ -83,7 +83,7 @@ def main_graph(
8383
val_36 = opset18.Reshape(val_34, val_35, allowzero=0)
8484
val_37 = opset18.Constant(value_ints=[1])
8585
slice_2 = opset18.Slice(position_ids, val_30, val_33, val_36, val_37)
86-
dim__5 = opset18.Constant(value_int=1)
86+
dim__5 = opset18.Constant(value_ints=[1])
8787
dim_0__5 = opset18.Cast(dim__5, to=7)
8888
unsqueeze_2 = opset18.Unsqueeze(slice_2, dim_0__5)
8989
val_38 = opset18.Cast(0, to=7)
@@ -160,10 +160,10 @@ def main_graph(
160160
val_71 = opset18.Cast([1, 30, 32, 64], to=7)
161161
view_12 = opset18.Reshape(view_9, val_71, allowzero=0)
162162
transpose_3 = opset18.Transpose(view_12, perm=[0, 2, 1, 3])
163-
dim__8 = opset18.Constant(value_int=1)
163+
dim__8 = opset18.Constant(value_ints=[1])
164164
dim_0__8 = opset18.Cast(dim__8, to=7)
165165
unsqueeze_3 = opset18.Unsqueeze(_to_copy_4, dim_0__8)
166-
dim__9 = opset18.Constant(value_int=1)
166+
dim__9 = opset18.Constant(value_ints=[1])
167167
dim_0__9 = opset18.Cast(dim__9, to=7)
168168
unsqueeze_4 = opset18.Unsqueeze(_to_copy_5, dim_0__9)
169169
mul_5 = transpose_1 * unsqueeze_3
@@ -222,10 +222,10 @@ def main_graph(
222222
add_2 = mul_7 + mul_8
223223
cat_3 = opset18.Concat(past_key_values_0_0, add_2, axis=-2)
224224
cat_4 = opset18.Concat(past_key_values_0_1, transpose_3, axis=-2)
225-
dim__10 = opset18.Constant(value_int=0)
225+
dim__10 = opset18.Constant(value_ints=[0])
226226
dim_0__10 = opset18.Cast(dim__10, to=7)
227227
unsqueeze_5 = opset18.Unsqueeze(mul, dim_0__10)
228-
dim__11 = opset18.Constant(value_int=1)
228+
dim__11 = opset18.Constant(value_ints=[1])
229229
dim_0__11 = opset18.Cast(dim__11, to=7)
230230
unsqueeze_6 = opset18.Unsqueeze(unsqueeze_5, dim_0__11)
231231
val_114 = opset18.Cast(0, to=7)

onnxscript/rewriter/ort_fusions/cos_sin_cache_test.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,7 @@ def test_cos_sin_fusion(self, name, test_data_constructor):
4545
original_outputs = ort_run("original", model, inputs)
4646
count = fuse_rotary_embedding(model)
4747
self.assertGreater(count, 0)
48-
count = fuse_cos_sin_cache(model)
48+
count = fuse_cos_sin_cache(model, debug=True)
4949
self.assertGreater(count, 0)
5050
new_outputs = ort_run("optimized", model, inputs)
5151
assert_allclose(new_outputs, original_outputs)

0 commit comments

Comments
 (0)