Skip to content

Commit db10f6e

Browse files
raikonenfnuGroverkss
authored andcommitted
[mlperf][pkgci] Update punet-fp8 with reduction dim as last dim (iree-org#19316)
We have changes in sharktank that converts reduction dim of the custom attention to be the fastest dimension. This makes it more uniform with the FP16 and canonical attention form and hopefully makes optimization gets called more easily down the line with this. Additionally, this is to prefetch S.T we do not break the coming sharktank/mlperf bots and runs. Signed-off-by: Stanley Winata <[email protected]>
1 parent 6a52491 commit db10f6e

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

build_tools/pkgci/external_test_suite/attention_and_matmul_spec_punet.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ transform.named_sequence @match_attention_f8(%attention: !transform.any_op {tran
7676
transform.iree.match.cast_compatible_type %in0 = tensor<?x?x?x?xf8E4M3FNUZ> : !transform.any_value
7777

7878
%config = transform.param.constant #iree_codegen.compilation_info<
79-
lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 1, 64, 0, 0, 0], reduction=[0, 0, 0, 0, 64, 0], promote_operands = [1, 2]}>,
79+
lowering_config = #iree_gpu.lowering_config<{workgroup = [1, 1, 64, 0, 0, 0], reduction=[0, 0, 0, 0, 0, 64], promote_operands = [1, 2]}>,
8080
translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
8181
workgroup_size = [64, 4]
8282
subgroup_size = 64 ,

experimental/regression_suite/shark-test-suite-models/sdxl/test_unet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@
121121
)
122122

123123
sdxl_punet_int8_fp8_mlir = fetch_source_fixture(
124-
"https://sharkpublic.blob.core.windows.net/sharkpublic/sai/sdxl-punet/11-8-2024/punet_fp8.mlir",
124+
"https://sharkpublic.blob.core.windows.net/sharkpublic/stan/sdxl-punet/11-26-2024/punet_fp8.mlir",
125125
group="sdxl_punet_int8_fp8",
126126
)
127127

0 commit comments

Comments
 (0)