
Commit 24ffead

[tuner] update the calculation of shared memory usage

Signed-off-by: Bangtian Liu <[email protected]>

1 parent: 3e8e700

4 files changed: +51 −12 lines

amdsharktuner/amdsharktuner/constraint_generator.py

Lines changed: 1 addition & 3 deletions

@@ -244,9 +244,7 @@ def set_cdim_tile_sizes(tile_sizes, contraction_dims, csizes):
     promote_operands = [0, 1]
     padding = None
     if required_padding:
-        # TODO: Remove promotion of operand 2 once codegen supports handling padded outputs without promotion.
-        promote_operands = [0, 1, 2]
-        _, _, mma_intrinsic_k = mma_attr.mnk_shape
+        mma_intrinsic_k = mma_attr.mnk_shape[2]
         padding = [
             *(workgroup_tile_sizes[d] for d in contraction_dims.m),
             *(workgroup_tile_sizes[d] for d in contraction_dims.n),
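
Note: after this change the tuner no longer promotes the output operand (operand 2) when padding is required; the mma intrinsic K size is read directly from mnk_shape[2], and the output tile is only charged to shared memory when promote_operands explicitly includes 2 (see the updated calculation below).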

amdsharktuner/amdsharktuner/dispatch_constraints.py

Lines changed: 23 additions & 3 deletions

@@ -161,17 +161,37 @@ def get_dispatch_constraints(
 def calculate_shared_memory_usage_in_bytes(
     lhs_type: common.ShapedType,
     rhs_type: common.ShapedType,
+    res_type: common.ShapedType,
     m: list[int] | list[z3.ArithRef],
     n: list[int] | list[z3.ArithRef],
     k: list[int] | list[z3.ArithRef],
+    promote_operands: list[int] = [0, 1],
 ) -> int | z3.ArithRef:
+    assert promote_operands == [0, 1] or promote_operands == [
+        0,
+        1,
+        2,
+    ], f"Got {promote_operands}"
+
     lhs_memory = lhs_type.bitwidth // 8
     for size in m + k:
         lhs_memory *= size
+
     rhs_memory = rhs_type.bitwidth // 8
     for size in n + k:
         rhs_memory *= size
-    return lhs_memory + rhs_memory
+
+    output_memory = res_type.bitwidth // 8
+    for size in m + n:
+        output_memory *= size
+
+    total_memory = (
+        int(0 in promote_operands) * lhs_memory
+        + int(1 in promote_operands) * rhs_memory
+        + int(2 in promote_operands) * output_memory
+    )
+
+    return total_memory


 def generate_vector_distribute_constraints(

@@ -258,7 +278,7 @@ def generate_vector_distribute_constraints(
     constraints += [subgroups >= 1, subgroups <= 10]

     shared_memory = calculate_shared_memory_usage_in_bytes(
-        lhs_type, rhs_type, [m], [n], [k]
+        lhs_type, rhs_type, res_type, [m], [n], [k]
     )
     constraints += [shared_memory <= gpu_target_info.max_workgroup_memory_bytes]

@@ -360,7 +380,7 @@ def generate_tile_and_fuse_constraints(
     constraints += [wg_threads == subgroups * subgroup_size]

     shared_memory = calculate_shared_memory_usage_in_bytes(
-        lhs_type, rhs_type, m_tiles, n_tiles, k_tiles
+        lhs_type, rhs_type, res_type, m_tiles, n_tiles, k_tiles
     )
     constraints += [
         shared_memory * intrinsic_k <= gpu_target_info.max_workgroup_memory_bytes
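
The updated formula only charges shared memory for operands that are actually promoted. A minimal standalone sketch of the arithmetic, using plain integer bit widths in place of common.ShapedType and omitting the z3.ArithRef symbolic path (the helper name shared_memory_bytes is hypothetical):

def shared_memory_bytes(lhs_bits, rhs_bits, res_bits, m, n, k, promote_operands=(0, 1)):
    # Bytes per operand tile: bytes-per-element times the product of tile sizes.
    lhs_memory = lhs_bits // 8
    for size in m + k:
        lhs_memory *= size
    rhs_memory = rhs_bits // 8
    for size in n + k:
        rhs_memory *= size
    output_memory = res_bits // 8
    for size in m + n:
        output_memory *= size
    # Only operands promoted to shared memory are counted.
    return (
        int(0 in promote_operands) * lhs_memory
        + int(1 in promote_operands) * rhs_memory
        + int(2 in promote_operands) * output_memory
    )

# f16 lhs/rhs (2 bytes), f32 result (4 bytes), tiles m=[512], n=[64], k=[128]:
assert shared_memory_bytes(16, 16, 32, [512], [64], [128]) == 147456             # 131072 + 16384
assert shared_memory_bytes(16, 16, 32, [512], [64], [128], (0, 1, 2)) == 278528  # + 131072 for the output tile

These values match the expectations in the updated dispatch_constraints_test.py below.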

amdsharktuner/tests/constraint_generator_test.py

Lines changed: 2 additions & 2 deletions

@@ -296,7 +296,7 @@ def test_generate_solutions_tile_and_fuse_contraction_padding(
         lowering_config
     ), f"Missing padding in lowering config: {lowering_config}"
     promote = [int(x) for x in lowering_config.attributes["promote_operands"]]
-    assert promote == [0, 1, 2]
+    assert promote == [0, 1]


 def test_generate_solutions_tile_and_fuse_conv_padding(

@@ -373,7 +373,7 @@ def test_generate_solutions_tile_and_fuse_conv_padding(
         lowering_config
     ), f"Missing padding in lowering config: {lowering_config}"
     promote = [int(x) for x in lowering_config.attributes["promote_operands"]]
-    assert promote == [0, 1, 2]
+    assert promote == [0, 1]


 def test_adjust_problem_size_for_pipeline(

amdsharktuner/tests/dispatch_constraints_test.py

Lines changed: 25 additions & 4 deletions

@@ -43,36 +43,57 @@ def gpu_target_info(tuner_ctx: common.TunerContext) -> iree_gpu.TargetInfo:
 def test_calculate_shared_memory_usage_in_bytes(tuner_ctx: common.TunerContext) -> None:
     lhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16)
     rhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16)
+    res_type = common.ShapedType([1024, 1024], tuner_ctx.type.f32)
     assert (
         dispatch_constraints.calculate_shared_memory_usage_in_bytes(
-            rhs_type, rhs_type, [512], [64], [128]
+            lhs_type, rhs_type, res_type, [512], [64], [128]
         )
         == 147456
     )

     lhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.i8)
     assert (
         dispatch_constraints.calculate_shared_memory_usage_in_bytes(
-            lhs_type, rhs_type, [512], [64], [128]
+            lhs_type, rhs_type, res_type, [512], [64], [128]
         )
         == 81920
     )

     rhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.i32)
     assert (
         dispatch_constraints.calculate_shared_memory_usage_in_bytes(
-            lhs_type, rhs_type, [128], [64], [32]
+            lhs_type, rhs_type, res_type, [128], [64], [32]
         )
         == 12288
     )

     assert (
         dispatch_constraints.calculate_shared_memory_usage_in_bytes(
-            lhs_type, rhs_type, [2, 64], [4, 16], [8, 4]
+            lhs_type, rhs_type, res_type, [2, 64], [4, 16], [8, 4]
         )
         == 12288
     )

+    lhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16)
+    rhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16)
+    res_type = common.ShapedType([1024, 1024], tuner_ctx.type.f32)
+    assert (
+        dispatch_constraints.calculate_shared_memory_usage_in_bytes(
+            lhs_type, rhs_type, res_type, [512], [64], [128], promote_operands=[0, 1, 2]
+        )
+        == 278528
+    )
+
+    with pytest.raises(AssertionError):
+        dispatch_constraints.calculate_shared_memory_usage_in_bytes(
+            lhs_type, rhs_type, res_type, [512], [64], [128], promote_operands=[0]
+        )
+
+    with pytest.raises(AssertionError):
+        dispatch_constraints.calculate_shared_memory_usage_in_bytes(
+            lhs_type, rhs_type, res_type, [512], [64], [128], promote_operands=[1, 2]
+        )
+

 def test_generate_tile_and_fuse_constraints_valid_input(
     tuner_ctx: common.TunerContext, gpu_target_info: iree_gpu.TargetInfo
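
For reference, each expected value is bytes-per-element times tile volume, summed over the promoted operands: 147456 = 2*512*128 + 2*64*128 (f16 lhs and rhs), 81920 = 1*512*128 + 2*64*128 (i8 lhs), and 12288 = 1*128*32 + 4*64*32 (i32 rhs); the multi-dimensional tiles [2, 64], [4, 16], [8, 4] multiply out to the same 128, 64, 32. With the default promote_operands=[0, 1] the f32 res_type does not contribute; promoting operand 2 adds the 4*512*64 = 131072-byte output tile, giving 278528.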
