
Commit 46c97f9

[ROCMTarget] Make all pingpong arithmetic nsw and nuw (iree-org#21248)
This is safe because all of the arithmetic is either thread-id-related arithmetic, which is known to be well within bounds given the required workgroup size, or pointer arithmetic, where wrapping would be UB.
1 parent 6f6e577 commit 46c97f9
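
For context on why the flags are sound: the thread ids feeding these ops are delinearized against small bases, so every component is tightly bounded and the products and sums stay far below the width of index. A minimal standalone sketch of the annotated pattern (hypothetical @overflow_flag_sketch using a plain func.func rather than the spec's util.func; not taken from the patch):

func.func private @overflow_flag_sketch(%id: index) -> index {
  // Delinearizing a linear thread id into (256, 8) bounds %delin#1 to [0, 8).
  %c8 = arith.constant 8 : index
  %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
  // Without flags the backend must assume the multiply may wrap:
  //   %vec = arith.muli %delin#1, %c8 : index
  // Given the bound above, the product is at most 56, so neither signed nor
  // unsigned wrapping can occur and both flags hold.
  %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw> : index
  return %vec : index
}

The same reasoning covers the addi chains below: each term is a bounded lane or subgroup offset, so the running sums never approach the index width, presumably allowing tighter address computation downstream.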

File tree

1 file changed: +63 / -63 lines


compiler/plugins/target/ROCM/builtins/tuning/iree_default_tuning_spec_gfx942.mlir

Lines changed: 63 additions & 63 deletions
@@ -78,14 +78,14 @@ util.func private @pingpong_large(%lhs_base: !in_ty, %rhs_base: !in_ty, %unused_

  scf.forall (%id) in (2048) {
    %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
-   %vec = arith.muli %delin#1, %c8 : index
+   %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw> : index
    %lhs_thread_local = tensor.extract_slice %lhs_init [%delin#0, %vec] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
    %lhs_vec_local = vector.transfer_read %lhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
    vector.transfer_write %lhs_vec_local, %lhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
  } {mapping = [#gpu.thread<linear_dim_0>]}
  scf.forall (%id) in (2048) {
    %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
-   %vec = arith.muli %delin#1, %c8 : index
+   %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw> : index
    %rhs_thread_local = tensor.extract_slice %rhs_init [%delin#0, %vec] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
    %rhs_vec_local = vector.transfer_read %rhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
    vector.transfer_write %rhs_vec_local, %rhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
@@ -97,21 +97,21 @@ util.func private @pingpong_large(%lhs_base: !in_ty, %rhs_base: !in_ty, %unused_
  %0 = tensor.empty() : tensor<16x16x16x16xf32>
  %1 = scf.forall (%id) in (512) shared_outs(%out = %0) -> tensor<16x16x16x16xf32> {
    %ids:4 = affine.delinearize_index %id into (2, 4, 4, 16) : index, index, index, index
-   %inner_id = arith.muli %ids#2, %c4 : index
-   %m_outer_id = arith.muli %ids#0, %c8 : index
-   %n_outer_id = arith.muli %ids#1, %c4 : index
+   %inner_id = arith.muli %ids#2, %c4 overflow<nsw, nuw> : index
+   %m_outer_id = arith.muli %ids#0, %c8 overflow<nsw, nuw> : index
+   %n_outer_id = arith.muli %ids#1, %c4 overflow<nsw, nuw> : index
    %delin:2 = affine.delinearize_index %id into (64, 8) : index, index
    %wt:3 = affine.delinearize_index %id into (8, 8, 8) : index, index, index

    // Inner 64 loads 8 threads x 8 elements.
-   %gko = arith.muli %wt#2, %c8 : index
+   %gko = arith.muli %wt#2, %c8 overflow<nsw, nuw> : index
    // Each subgroup loads 32 contiguous rows out of 256.
-   %bpo = arith.muli %wt#0, %c32 : index
+   %bpo = arith.muli %wt#0, %c32 overflow<nsw, nuw> : index
    // Base index is remaining outer 8 lanes + subgroup base.
-   %glb0 = arith.addi %wt#1, %bpo : index
-   %glb1 = arith.addi %glb0, %c8 : index
-   %glb2 = arith.addi %glb1, %c8 : index
-   %glb3 = arith.addi %glb2, %c8 : index
+   %glb0 = arith.addi %wt#1, %bpo overflow<nsw, nuw> : index
+   %glb1 = arith.addi %glb0, %c8 overflow<nsw, nuw> : index
+   %glb2 = arith.addi %glb1, %c8 overflow<nsw, nuw> : index
+   %glb3 = arith.addi %glb2, %c8 overflow<nsw, nuw> : index

    %2 = arith.constant dense<0.0> : vector<8x4x1x4xf32>

@@ -299,14 +299,14 @@ util.func private @pingpong_large_expanded(%lhs_base: !exp_in_ty, %rhs_base: !in

  scf.forall (%id) in (2048) {
    %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
-   %vec = arith.muli %delin#1, %c8 : index
+   %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw> : index
    %lhs_thread_local = tensor.extract_slice %lhs_init [0, %delin#0, %vec] [1, 1, 8] [1, 1, 1] : !exp_block_in to tensor<1x1x8xf16>
    %lhs_vec_local = vector.transfer_read %lhs_thread_local [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xf16>, vector<1x8xf16>
    vector.transfer_write %lhs_vec_local, %lhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
  } {mapping = [#gpu.thread<linear_dim_0>]}
  scf.forall (%id) in (2048) {
    %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
-   %vec = arith.muli %delin#1, %c8 : index
+   %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw> : index
    %rhs_thread_local = tensor.extract_slice %rhs_init [%delin#0, %vec] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
    %rhs_vec_local = vector.transfer_read %rhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
    vector.transfer_write %rhs_vec_local, %rhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
@@ -318,21 +318,21 @@ util.func private @pingpong_large_expanded(%lhs_base: !exp_in_ty, %rhs_base: !in
  %0 = tensor.empty() : tensor<1x16x16x16x16xf32>
  %1 = scf.forall (%id) in (512) shared_outs(%out = %0) -> tensor<1x16x16x16x16xf32> {
    %ids:4 = affine.delinearize_index %id into (2, 4, 4, 16) : index, index, index, index
-   %inner_id = arith.muli %ids#2, %c4 : index
-   %m_outer_id = arith.muli %ids#0, %c8 : index
-   %n_outer_id = arith.muli %ids#1, %c4 : index
+   %inner_id = arith.muli %ids#2, %c4 overflow<nsw, nuw> : index
+   %m_outer_id = arith.muli %ids#0, %c8 overflow<nsw, nuw> : index
+   %n_outer_id = arith.muli %ids#1, %c4 overflow<nsw, nuw> : index
    %delin:2 = affine.delinearize_index %id into (64, 8) : index, index
    %wt:3 = affine.delinearize_index %id into (8, 8, 8) : index, index, index

    // Inner 64 loads 8 threads x 8 elements.
-   %gko = arith.muli %wt#2, %c8 : index
+   %gko = arith.muli %wt#2, %c8 overflow<nsw, nuw> : index
    // Each subgroup loads 32 contiguous rows out of 256.
-   %bpo = arith.muli %wt#0, %c32 : index
+   %bpo = arith.muli %wt#0, %c32 overflow<nsw, nuw> : index
    // Base index is remaining outer 8 lanes + subgroup base.
-   %glb0 = arith.addi %wt#1, %bpo : index
-   %glb1 = arith.addi %glb0, %c8 : index
-   %glb2 = arith.addi %glb1, %c8 : index
-   %glb3 = arith.addi %glb2, %c8 : index
+   %glb0 = arith.addi %wt#1, %bpo overflow<nsw, nuw> : index
+   %glb1 = arith.addi %glb0, %c8 overflow<nsw, nuw> : index
+   %glb2 = arith.addi %glb1, %c8 overflow<nsw, nuw> : index
+   %glb3 = arith.addi %glb2, %c8 overflow<nsw, nuw> : index

    %2 = arith.constant dense<0.0> : vector<8x4x1x4xf32>

@@ -525,14 +525,14 @@ util.func private @pingpong_large_f8_expanded(%lhs_base: !exp_in_ty_f8, %rhs_bas

  scf.forall (%id) in (2048) {
    %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
-   %vec = arith.muli %delin#1, %c16 : index
+   %vec = arith.muli %delin#1, %c16 overflow<nsw, nuw> : index
    %lhs_thread_local = tensor.extract_slice %lhs_init [0, %delin#0, %vec] [1, 1, 16] [1, 1, 1] : !exp_block_in_f8 to tensor<1x1x16xf8E4M3FNUZ>
    %lhs_vec_local = vector.transfer_read %lhs_thread_local [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
    vector.transfer_write %lhs_vec_local, %lhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
  } {mapping = [#gpu.thread<linear_dim_0>]}
  scf.forall (%id) in (2048) {
    %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
-   %vec = arith.muli %delin#1, %c16 : index
+   %vec = arith.muli %delin#1, %c16 overflow<nsw, nuw> : index
    %rhs_thread_local = tensor.extract_slice %rhs_init [%delin#0, %vec] [1, 16] [1, 1] : !block_in_f8 to tensor<1x16xf8E4M3FNUZ>
    %rhs_vec_local = vector.transfer_read %rhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
    vector.transfer_write %rhs_vec_local, %rhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
@@ -544,22 +544,22 @@ util.func private @pingpong_large_f8_expanded(%lhs_base: !exp_in_ty_f8, %rhs_bas
  %0 = tensor.empty() : tensor<1x16x16x16x16xf32>
  %1 = scf.forall (%id) in (512) shared_outs(%out = %0) -> tensor<1x16x16x16x16xf32> {
    %ids:4 = affine.delinearize_index %id into (2, 4, 4, 16) : index, index, index, index
-   %inner_id = arith.muli %ids#2, %c8 : index
-   %inner_id_acc = arith.muli %ids#2, %c4 : index
-   %m_outer_id = arith.muli %ids#0, %c8 : index
-   %n_outer_id = arith.muli %ids#1, %c4 : index
+   %inner_id = arith.muli %ids#2, %c8 overflow<nsw, nuw> : index
+   %inner_id_acc = arith.muli %ids#2, %c4 overflow<nsw, nuw> : index
+   %m_outer_id = arith.muli %ids#0, %c8 overflow<nsw, nuw> : index
+   %n_outer_id = arith.muli %ids#1, %c4 overflow<nsw, nuw> : index
    %delin:2 = affine.delinearize_index %id into (64, 8) : index, index
    %wt:3 = affine.delinearize_index %id into (8, 8, 8) : index, index, index

    // Inner 64 loads 8 threads x 16 elements.
-   %gko = arith.muli %wt#2, %c16 : index
+   %gko = arith.muli %wt#2, %c16 overflow<nsw, nuw> : index
    // Each subgroup loads 32 contiguous rows out of 256.
-   %bpo = arith.muli %wt#0, %c32 : index
+   %bpo = arith.muli %wt#0, %c32 overflow<nsw, nuw> : index
    // Base index is remaining outer 8 lanes + subgroup base.
-   %glb0 = arith.addi %wt#1, %bpo : index
-   %glb1 = arith.addi %glb0, %c8 : index
-   %glb2 = arith.addi %glb1, %c8 : index
-   %glb3 = arith.addi %glb2, %c8 : index
+   %glb0 = arith.addi %wt#1, %bpo overflow<nsw, nuw> : index
+   %glb1 = arith.addi %glb0, %c8 overflow<nsw, nuw> : index
+   %glb2 = arith.addi %glb1, %c8 overflow<nsw, nuw> : index
+   %glb3 = arith.addi %glb2, %c8 overflow<nsw, nuw> : index

    %2 = arith.constant dense<0.0> : vector<8x4x1x4xf32>

@@ -751,14 +751,14 @@ util.func private @pingpong_medium_expanded(%lhs_base: !mexp_in_ty, %rhs_base: !

  scf.forall (%id) in (1024) {
    %delin:2 = affine.delinearize_index %id into (128, 8) : index, index
-   %vec = arith.muli %delin#1, %c8 : index
+   %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw> : index
    %lhs_thread_local = tensor.extract_slice %lhs_init [0, %delin#0, %vec] [1, 1, 8] [1, 1, 1] : !mexp_block_in to tensor<1x1x8xf16>
    %lhs_vec_local = vector.transfer_read %lhs_thread_local [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x8xf16>, vector<1x8xf16>
    vector.transfer_write %lhs_vec_local, %lhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xf16>, !mshared
  } {mapping = [#gpu.thread<linear_dim_0>]}
  scf.forall (%id) in (2048) {
    %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
-   %vec = arith.muli %delin#1, %c8 : index
+   %vec = arith.muli %delin#1, %c8 overflow<nsw, nuw> : index
    %rhs_thread_local = tensor.extract_slice %rhs_init [%delin#0, %vec] [1, 8] [1, 1] : !block_in to tensor<1x8xf16>
    %rhs_vec_local = vector.transfer_read %rhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x8xf16>, vector<1x8xf16>
    vector.transfer_write %rhs_vec_local, %rhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x8xf16>, !shared
@@ -770,25 +770,25 @@ util.func private @pingpong_medium_expanded(%lhs_base: !mexp_in_ty, %rhs_base: !
  %0 = tensor.empty() : tensor<1x8x16x16x16xf32>
  %1 = scf.forall (%id) in (512) shared_outs(%out = %0) -> tensor<1x8x16x16x16xf32> {
    %ids:4 = affine.delinearize_index %id into (2, 4, 4, 16) : index, index, index, index
-   %inner_id = arith.muli %ids#2, %c4 : index
-   %m_outer_id = arith.muli %ids#0, %c4 : index
-   %n_outer_id = arith.muli %ids#1, %c4 : index
+   %inner_id = arith.muli %ids#2, %c4 overflow<nsw, nuw> : index
+   %m_outer_id = arith.muli %ids#0, %c4 overflow<nsw, nuw> : index
+   %n_outer_id = arith.muli %ids#1, %c4 overflow<nsw, nuw> : index
    %delin:2 = affine.delinearize_index %id into (64, 8) : index, index
    %wt:3 = affine.delinearize_index %id into (8, 8, 8) : index, index, index

    // Inner 64 loads 8 threads x 8 elements.
-   %gko = arith.muli %wt#2, %c8 : index
+   %gko = arith.muli %wt#2, %c8 overflow<nsw, nuw> : index
    // RHS indexing. Each subgroup loads 32 contiguous rows out of 256.
-   %bpo = arith.muli %wt#0, %c32 : index
+   %bpo = arith.muli %wt#0, %c32 overflow<nsw, nuw> : index
    // Base index is remaining outer 8 lanes + subgroup base.
-   %glb0 = arith.addi %wt#1, %bpo : index
-   %glb1 = arith.addi %glb0, %c8 : index
-   %glb2 = arith.addi %glb1, %c8 : index
-   %glb3 = arith.addi %glb2, %c8 : index
+   %glb0 = arith.addi %wt#1, %bpo overflow<nsw, nuw> : index
+   %glb1 = arith.addi %glb0, %c8 overflow<nsw, nuw> : index
+   %glb2 = arith.addi %glb1, %c8 overflow<nsw, nuw> : index
+   %glb3 = arith.addi %glb2, %c8 overflow<nsw, nuw> : index
    // LHS indexing.
-   %bpo_lhs = arith.muli %wt#0, %c16 : index
-   %glb0_lhs = arith.addi %wt#1, %bpo_lhs : index
-   %glb1_lhs = arith.addi %glb0_lhs, %c8 : index
+   %bpo_lhs = arith.muli %wt#0, %c16 overflow<nsw, nuw> : index
+   %glb0_lhs = arith.addi %wt#1, %bpo_lhs overflow<nsw, nuw> : index
+   %glb1_lhs = arith.addi %glb0_lhs, %c8 overflow<nsw, nuw> : index

    %2 = arith.constant dense<0.0> : vector<4x4x1x4xf32>

@@ -944,14 +944,14 @@ util.func private @pingpong_medium_f8_expanded(%lhs_base: !mexp_in_ty_f8, %rhs_b

  scf.forall (%id) in (1024) {
    %delin:2 = affine.delinearize_index %id into (128, 8) : index, index
-   %vec = arith.muli %delin#1, %c16 : index
+   %vec = arith.muli %delin#1, %c16 overflow<nsw, nuw> : index
    %lhs_thread_local = tensor.extract_slice %lhs_init [0, %delin#0, %vec] [1, 1, 16] [1, 1, 1] : !mexp_block_in_f8 to tensor<1x1x16xf8E4M3FNUZ>
    %lhs_vec_local = vector.transfer_read %lhs_thread_local [%c0, %c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
    vector.transfer_write %lhs_vec_local, %lhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !mshared_f8
  } {mapping = [#gpu.thread<linear_dim_0>]}
  scf.forall (%id) in (2048) {
    %delin:2 = affine.delinearize_index %id into (256, 8) : index, index
-   %vec = arith.muli %delin#1, %c16 : index
+   %vec = arith.muli %delin#1, %c16 overflow<nsw, nuw> : index
    %rhs_thread_local = tensor.extract_slice %rhs_init [%delin#0, %vec] [1, 16] [1, 1] : !block_in_f8 to tensor<1x16xf8E4M3FNUZ>
    %rhs_vec_local = vector.transfer_read %rhs_thread_local [%c0, %c0], %cst {in_bounds = [true, true]} : tensor<1x16xf8E4M3FNUZ>, vector<1x16xf8E4M3FNUZ>
    vector.transfer_write %rhs_vec_local, %rhs_shared[%delin#0, %vec] {in_bounds = [true, true]} : vector<1x16xf8E4M3FNUZ>, !shared_f8
@@ -963,26 +963,26 @@ util.func private @pingpong_medium_f8_expanded(%lhs_base: !mexp_in_ty_f8, %rhs_b
  %0 = tensor.empty() : tensor<1x8x16x16x16xf32>
  %1 = scf.forall (%id) in (512) shared_outs(%out = %0) -> tensor<1x8x16x16x16xf32> {
    %ids:4 = affine.delinearize_index %id into (2, 4, 4, 16) : index, index, index, index
-   %inner_id = arith.muli %ids#2, %c8 : index
-   %inner_id_acc = arith.muli %ids#2, %c4 : index
-   %m_outer_id = arith.muli %ids#0, %c4 : index
-   %n_outer_id = arith.muli %ids#1, %c4 : index
+   %inner_id = arith.muli %ids#2, %c8 overflow<nsw, nuw> : index
+   %inner_id_acc = arith.muli %ids#2, %c4 overflow<nsw, nuw> : index
+   %m_outer_id = arith.muli %ids#0, %c4 overflow<nsw, nuw> : index
+   %n_outer_id = arith.muli %ids#1, %c4 overflow<nsw, nuw> : index
    %delin:2 = affine.delinearize_index %id into (64, 8) : index, index
    %wt:3 = affine.delinearize_index %id into (8, 8, 8) : index, index, index

    // Inner 64 loads 8 threads x 16 elements.
-   %gko = arith.muli %wt#2, %c16 : index
+   %gko = arith.muli %wt#2, %c16 overflow<nsw, nuw> : index
    // RHS indexing. Each subgroup loads 32 contiguous rows out of 256.
-   %bpo = arith.muli %wt#0, %c32 : index
+   %bpo = arith.muli %wt#0, %c32 overflow<nsw, nuw> : index
    // Base index is remaining outer 8 lanes + subgroup base.
-   %glb0 = arith.addi %wt#1, %bpo : index
-   %glb1 = arith.addi %glb0, %c8 : index
-   %glb2 = arith.addi %glb1, %c8 : index
-   %glb3 = arith.addi %glb2, %c8 : index
+   %glb0 = arith.addi %wt#1, %bpo overflow<nsw, nuw> : index
+   %glb1 = arith.addi %glb0, %c8 overflow<nsw, nuw> : index
+   %glb2 = arith.addi %glb1, %c8 overflow<nsw, nuw> : index
+   %glb3 = arith.addi %glb2, %c8 overflow<nsw, nuw> : index
    // LHS indexing.
-   %bpo_lhs = arith.muli %wt#0, %c16 : index
-   %glb0_lhs = arith.addi %wt#1, %bpo_lhs : index
-   %glb1_lhs = arith.addi %glb0_lhs, %c8 : index
+   %bpo_lhs = arith.muli %wt#0, %c16 overflow<nsw, nuw> : index
+   %glb0_lhs = arith.addi %wt#1, %bpo_lhs overflow<nsw, nuw> : index
+   %glb1_lhs = arith.addi %glb0_lhs, %c8 overflow<nsw, nuw> : index

    %2 = arith.constant dense<0.0> : vector<4x4x1x4xf32>
