Skip to content

Commit d847b10

Browse files
authored
[LoadStoreOpToLLVM] Broadcast the result for atomic rmw & cas ops (#5031)
This PR fixes improper result broadcasting for atomic rmw & cas ops reported in #4879. 16-bit atomic cas test cases remain skipped due to #5025.
1 parent 73a1c3b commit d847b10

File tree

10 files changed

+38
-10
lines changed

10 files changed

+38
-10
lines changed

python/test/unit/language/test_core.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2013,8 +2013,6 @@ def kernel(I, O):
20132013
@pytest.mark.parametrize("size", [1, 4, 16])
20142014
@pytest.mark.parametrize("op", ["add", "cas"])
20152015
def test_tensor_atomic_use_result(dtype_str, size, op, device):
2016-
if is_xpu():
2017-
pytest.skip("FIXME: issue #4879")
20182016
if is_hip():
20192017
pytest.skip(
20202018
"HIP is broken because (1) it doesn't support thread predicate in atomic cas, and (2) it doesn't support"

scripts/skiplist/a770/language.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,3 +751,7 @@ python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_batc
751751
python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_batched_gemm_3d_tma
752752
# test_tensor_atomic_add_access_patterns
753753
python/test/unit/language/test_core.py::test_tensor_atomic_add_access_patterns[shape128-random_no_duplication-3-1-float32]
754+
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5025
755+
python/test/unit/language/test_core.py::test_tensor_atomic_use_result[cas-1-float16]
756+
python/test/unit/language/test_core.py::test_tensor_atomic_use_result[cas-4-float16]
757+
python/test/unit/language/test_core.py::test_tensor_atomic_use_result[cas-16-float16]

scripts/skiplist/arl-h/language.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,3 +609,7 @@ python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_redu
609609
python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[8-32-host-1-uint32-min]
610610
python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[8-32-host-1-uint32-or]
611611
python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[8-32-host-1-uint32-xor]
612+
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5025
613+
python/test/unit/language/test_core.py::test_tensor_atomic_use_result[cas-1-float16]
614+
python/test/unit/language/test_core.py::test_tensor_atomic_use_result[cas-4-float16]
615+
python/test/unit/language/test_core.py::test_tensor_atomic_use_result[cas-16-float16]

scripts/skiplist/arl-s/language.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -609,3 +609,7 @@ python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_redu
609609
python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[8-32-host-1-uint32-min]
610610
python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[8-32-host-1-uint32-or]
611611
python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[8-32-host-1-uint32-xor]
612+
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5025
613+
python/test/unit/language/test_core.py::test_tensor_atomic_use_result[cas-1-float16]
614+
python/test/unit/language/test_core.py::test_tensor_atomic_use_result[cas-4-float16]
615+
python/test/unit/language/test_core.py::test_tensor_atomic_use_result[cas-16-float16]

scripts/skiplist/conda/language.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,3 +235,7 @@ python/test/unit/language/test_core.py::test_const[if-False-False]
235235
python/test/unit/language/test_core.py::test_unroll_attr
236236
python/test/unit/language/test_decorator.py::test_triton_heuristic
237237
python/test/unit/language/test_core.py::test_constexpr_if_return
238+
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5025
239+
python/test/unit/language/test_core.py::test_tensor_atomic_use_result[cas-1-float16]
240+
python/test/unit/language/test_core.py::test_tensor_atomic_use_result[cas-4-float16]
241+
python/test/unit/language/test_core.py::test_tensor_atomic_use_result[cas-16-float16]

scripts/skiplist/default/language.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,7 @@ python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_redu
8686
python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[8-32-host-1-uint32-min]
8787
python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[8-32-host-1-uint32-or]
8888
python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[8-32-host-1-uint32-xor]
89+
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5025
90+
python/test/unit/language/test_core.py::test_tensor_atomic_use_result[cas-1-float16]
91+
python/test/unit/language/test_core.py::test_tensor_atomic_use_result[cas-4-float16]
92+
python/test/unit/language/test_core.py::test_tensor_atomic_use_result[cas-16-float16]

scripts/skiplist/lts/language.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,3 +337,7 @@ python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_redu
337337
python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[8-32-host-1-uint32-min]
338338
python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[8-32-host-1-uint32-or]
339339
python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[8-32-host-1-uint32-xor]
340+
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5025
341+
python/test/unit/language/test_core.py::test_tensor_atomic_use_result[cas-1-float16]
342+
python/test/unit/language/test_core.py::test_tensor_atomic_use_result[cas-4-float16]
343+
python/test/unit/language/test_core.py::test_tensor_atomic_use_result[cas-16-float16]

scripts/skiplist/mtl/language.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,3 +386,7 @@ python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_redu
386386
python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[8-32-host-1-uint32-min]
387387
python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[8-32-host-1-uint32-or]
388388
python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[8-32-host-1-uint32-xor]
389+
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5025
390+
python/test/unit/language/test_core.py::test_tensor_atomic_use_result[cas-1-float16]
391+
python/test/unit/language/test_core.py::test_tensor_atomic_use_result[cas-4-float16]
392+
python/test/unit/language/test_core.py::test_tensor_atomic_use_result[cas-16-float16]

scripts/skiplist/xe2/language.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,7 @@ python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_redu
8383
python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[8-32-host-1-uint32-min]
8484
python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[8-32-host-1-uint32-or]
8585
python/test/unit/language/test_tensor_descriptor.py::test_tensor_descriptor_reduce[8-32-host-1-uint32-xor]
86+
# https://github.com/intel/intel-xpu-backend-for-triton/issues/5025
87+
python/test/unit/language/test_core.py::test_tensor_atomic_use_result[cas-1-float16]
88+
python/test/unit/language/test_core.py::test_tensor_atomic_use_result[cas-4-float16]
89+
python/test/unit/language/test_core.py::test_tensor_atomic_use_result[cas-16-float16]

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3237,10 +3237,9 @@ struct AtomicCASOpConversion
32373237
}
32383238

32393239
if (tensorTy) {
3240-
Type structTy = getTypeConverter()->convertType(tensorTy);
3241-
Value resultStruct = packLLElements(loc, getTypeConverter(), resultVals,
3242-
rewriter, structTy);
3243-
rewriter.replaceOp(op, {resultStruct});
3240+
finalizeTensorAtomicResults(op, tensorTy, rewriter, resultVals,
3241+
valueElemTy, b, mask, targetInfo,
3242+
getTypeConverter());
32443243
}
32453244
return success();
32463245
}
@@ -3407,10 +3406,9 @@ struct AtomicRMWOpConversion
34073406
}
34083407

34093408
if (tensorTy) {
3410-
Type structTy = getTypeConverter()->convertType(tensorTy);
3411-
Value resultStruct = packLLElements(loc, getTypeConverter(), resultVals,
3412-
rewriter, structTy);
3413-
rewriter.replaceOp(op, {resultStruct});
3409+
finalizeTensorAtomicResults(op, tensorTy, rewriter, resultVals,
3410+
valueElemTy, b, threadPred, targetInfo,
3411+
getTypeConverter());
34143412
}
34153413
return success();
34163414
}

0 commit comments

Comments
 (0)