diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index a7bdb3603c..5397bce010 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -3557,7 +3557,7 @@ def test_dot3d(B, num_warps, M, N, K, BLOCK_M, BLOCK_N, in_dtype_str, out_dtype_
         if out_dtype_str == "float16":
             pytest.skip(f"{out_dtype_str} has low precision in WMMA dot")
     else:
-        input_precision = "tf32" if is_cuda() and in_dtype_str == 'float32' else "ieee"
+        input_precision = "tf32" if (is_cuda() or is_xpu()) and in_dtype_str == 'float32' else "ieee"
 
     if B == 8 and M == 64 and in_dtype_str == "float32" and out_dtype_str == "float32":
         if not is_interpreter() and torch.cuda.is_available(
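Note on the hunk above: `input_precision` feeds straight into the `tl.dot` call inside the test kernel, so fp32 inputs now take the TF32 path on XPU as well as CUDA. A minimal sketch of that pattern in kernel form (illustrative only; `batched_dot_kernel` and its pointer arithmetic are not taken from the test):

    import triton
    import triton.language as tl

    @triton.jit
    def batched_dot_kernel(x_ptr, y_ptr, z_ptr,
                           M: tl.constexpr, N: tl.constexpr, K: tl.constexpr,
                           INPUT_PRECISION: tl.constexpr):
        # One program per batch slice; x is B x M x K, y is B x K x N.
        pid = tl.program_id(0)
        offs_m = tl.arange(0, M)
        offs_n = tl.arange(0, N)
        offs_k = tl.arange(0, K)
        x = tl.load(x_ptr + pid * M * K + offs_m[:, None] * K + offs_k[None, :])
        y = tl.load(y_ptr + pid * K * N + offs_k[:, None] * N + offs_n[None, :])
        # "tf32" for fp32 inputs on CUDA/XPU, "ieee" otherwise (mirrors the test logic).
        z = tl.dot(x, y, input_precision=INPUT_PRECISION)
        tl.store(z_ptr + pid * M * N + offs_m[:, None] * N + offs_n[None, :], z)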
diff --git a/scripts/skiplist/a770/language.txt b/scripts/skiplist/a770/language.txt
index 6682e5d059..359d590de7 100644
--- a/scripts/skiplist/a770/language.txt
+++ b/scripts/skiplist/a770/language.txt
@@ -1,166 +1,5 @@
 # https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
 test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
-test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-16-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-16-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-16-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-16-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-16-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-16-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-16-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-16-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-2-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-2-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-2-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-2-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-2-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-2-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-2-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-2-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-4-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-4-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-4-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-4-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-4-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-4-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-4-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-4-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-8-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-8-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-8-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-8-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-8-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-8-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-8-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-8-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-1-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-1-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-1-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-1-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-1-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-1-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-1-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-1-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-16-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-16-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-16-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-16-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-16-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-16-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-16-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-16-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-2-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-2-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-2-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-2-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-2-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-2-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-2-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-2-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-4-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-4-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-4-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-4-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-4-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-4-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-4-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-4-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-8-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-8-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-8-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-8-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-8-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-8-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-8-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-8-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-1-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-1-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-1-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-1-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-16-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-16-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-16-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-16-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-16-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-16-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-16-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-16-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-2-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-2-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-2-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-2-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-2-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-2-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-2-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-2-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-4-128-128-64-64-64-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-4-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-4-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-4-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-4-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-4-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-4-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-4-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-4-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-8-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-8-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-8-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-8-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-8-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-8-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-8-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-8-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-1-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-1-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-1-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-1-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-16-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-16-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-16-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-16-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-16-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-16-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-16-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-16-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-4-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-4-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-4-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-4-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-8-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-8-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-8-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-8-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-int8-int8]
 # https://github.com/intel/intel-xpu-backend-for-triton/issues/983
 test/unit/language/test_core.py::test_noinline[shared]
 test/unit/language/test_core.py::test_dot[1-128-128-64-2-True-True-none-tf32-int8-int8-1_0]
diff --git a/scripts/skiplist/default/language.txt b/scripts/skiplist/default/language.txt
index cdac848de1..fd1a7e0a48 100644
--- a/scripts/skiplist/default/language.txt
+++ b/scripts/skiplist/default/language.txt
@@ -1,163 +1,2 @@
 # https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
 test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
-test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-16-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-16-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-16-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-16-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-16-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-16-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-16-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-16-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-2-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-2-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-2-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-2-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-2-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-2-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-2-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-2-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-4-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-4-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-4-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-4-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-4-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-4-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-4-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-4-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-8-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-8-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-8-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-8-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-8-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-8-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-8-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-8-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-1-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-1-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-1-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-1-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-1-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-1-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-1-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-1-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-16-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-16-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-16-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-16-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-16-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-16-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-16-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-16-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-2-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-2-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-2-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-2-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-2-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-2-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-2-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-2-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-4-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-4-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-4-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-4-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-4-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-4-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-4-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-4-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-8-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-8-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-8-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-8-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-8-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-8-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-8-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-8-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-1-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-1-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-1-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-1-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-16-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-16-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-16-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-16-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-16-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-16-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-16-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-16-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-2-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-2-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-2-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-2-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-2-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-2-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-2-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-2-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-4-128-128-64-64-64-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-4-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-4-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-4-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-4-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-4-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-4-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-4-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-4-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-8-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-8-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-8-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-8-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-8-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-8-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-8-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-8-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-1-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-1-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-1-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-1-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-16-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-16-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-16-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-16-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-16-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-16-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-16-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-16-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-4-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-4-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-4-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-4-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-8-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-8-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-8-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-8-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-int8-int8]
diff --git a/scripts/skiplist/xe2/language.txt b/scripts/skiplist/xe2/language.txt
index cdac848de1..fd1a7e0a48 100644
--- a/scripts/skiplist/xe2/language.txt
+++ b/scripts/skiplist/xe2/language.txt
@@ -1,163 +1,2 @@
 # https://github.com/intel/intel-xpu-backend-for-triton/issues/1434
 test/unit/language/test_core.py::test_precise_math[1-tl.math.sqrt_rn(x)-tl.math.sqrt(x.to(tl.float64)).to(tl.float32)]
-test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-1-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-1-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-16-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-16-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-16-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-16-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-16-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-16-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-16-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-16-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-2-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-2-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-2-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-2-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-2-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-2-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-2-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-2-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-4-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-4-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-4-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-4-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-4-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-4-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-4-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-4-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-8-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-8-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-8-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-8-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[1-8-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[1-8-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[1-8-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[1-8-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-1-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-1-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-1-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-1-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-1-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-1-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-1-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-1-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-16-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-16-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-16-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-16-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-16-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-16-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-16-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-16-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-2-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-2-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-2-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-2-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-2-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-2-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-2-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-2-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-4-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-4-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-4-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-4-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-4-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-4-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-4-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-4-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-8-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-8-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-8-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-8-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[2-8-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[2-8-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[2-8-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[2-8-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-1-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-1-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-1-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-1-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-1-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-16-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-16-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-16-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-16-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-16-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-16-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-16-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-16-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-2-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-2-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-2-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-2-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-2-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-2-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-2-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-2-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-4-128-128-64-64-64-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-4-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-4-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-4-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-4-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-4-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-4-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-4-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-4-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-8-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-8-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-8-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-8-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[4-8-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[4-8-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[4-8-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[4-8-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-1-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-1-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-1-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-1-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-1-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-16-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-16-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-16-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-16-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-16-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-16-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-16-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-16-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-2-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-2-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-4-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-4-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-4-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-4-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-4-64-64-64-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-8-32-32-32-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-8-32-32-32-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-8-32-32-32-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-8-32-32-32-32-32-int8-int8]
-test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float16-float16]
-test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float16-float32]
-test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-float32-float32]
-test/unit/language/test_core.py::test_dot3d[8-8-64-64-64-32-32-int8-int8]
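The Dialect.cpp changes below generalize the DPAS layout helpers from a fixed rank of 2 to rank 2 or 3: matrix dimensions are indexed from the back (`rank - 2`, `rank - 1`), and a leading dimension, when present, is the batch. A rough Python model of the new `getShapeC` logic under that reading (names are illustrative, not from the patch):

    def dpas_shape_c(inst_shape_c, rep_cluster):
        rank = len(rep_cluster)   # 2 for a plain dot, 3 for a batched dot
        res = [1] * rank          # a leading batch dim stays 1 per instruction
        res[rank - 2] = inst_shape_c[0] * rep_cluster[rank - 2]
        res[rank - 1] = inst_shape_c[1] * rep_cluster[rank - 1]
        return res

    assert dpas_shape_c([8, 16], [4, 2]) == [32, 32]        # rank 2: unchanged
    assert dpas_shape_c([8, 16], [1, 4, 2]) == [1, 32, 32]  # rank 3: batch kept at 1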
diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
index c6f314aa50..28e374932b 100644
--- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
+++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp
@@ -102,51 +102,80 @@ SmallVector<unsigned> DpasEncodingAttr::getDPASInstShapeC() const {
 };
 
 SmallVector<unsigned> DpasEncodingAttr::getShapeA() const {
-  auto shapeA = getDPASInstShapeA();
+  auto instShapeA = getDPASInstShapeA();
   auto repCluster = getRepCluster();
-  return {shapeA[0] * repCluster[0], shapeA[1]};
+  size_t rank = repCluster.size();
+  SmallVector<unsigned> resShape(rank, 1);
+  resShape[rank - 2] = instShapeA[0] * repCluster[rank - 2];
+  resShape[rank - 1] = instShapeA[1];
+  return resShape;
 }
 
 SmallVector<unsigned> DpasEncodingAttr::getShapeB() const {
-  auto shapeB = getDPASInstShapeB();
+  auto instShapeB = getDPASInstShapeB();
   auto repCluster = getRepCluster();
-  return {shapeB[0], shapeB[1] * repCluster[1]};
+  size_t rank = repCluster.size();
+  SmallVector<unsigned> resShape(rank, 1);
+  resShape[rank - 2] = instShapeB[0];
+  resShape[rank - 1] = instShapeB[1] * repCluster[rank - 1];
+  return resShape;
 }
 
 SmallVector<unsigned> DpasEncodingAttr::getShapeC() const {
-  auto shapeC = getDPASInstShapeC();
+  auto instShapeC = getDPASInstShapeC();
   auto repCluster = getRepCluster();
-  return {shapeC[0] * repCluster[0], shapeC[1] * repCluster[1]};
+  size_t rank = repCluster.size();
+  SmallVector<unsigned> resShape(rank, 1);
+  resShape[rank - 2] = instShapeC[0] * repCluster[rank - 2];
+  resShape[rank - 1] = instShapeC[1] * repCluster[rank - 1];
+  return resShape;
 }
 
 SmallVector<unsigned> DpasEncodingAttr::getSizePerThread() const {
+  size_t rank = getWarpsPerCTA().size();
+  SmallVector<unsigned> res(rank, 1);
   unsigned threadsPerWarp = getSubGroupSize();
   auto shapeC = getDPASInstShapeC();
   unsigned elemsNum = product(shapeC);
   unsigned elemsPerThread = elemsNum / threadsPerWarp;
   auto repCluster = getRepCluster();
   // The Value is shard to lanes to threads per DPAS instruction.
-  return {elemsPerThread * repCluster[0], repCluster[1]};
+  if (rank == 3)
+    res[0] = repCluster[0];
+  res[rank - 2] = elemsPerThread * repCluster[rank - 2];
+  res[rank - 1] = repCluster[rank - 1];
+  return res;
 }
 
 SmallVector<unsigned>
 DpasEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
   auto shapeC = getShapeC();
-  return {shapeC[0] * getWarpsPerCTA()[0], shapeC[1] * getWarpsPerCTA()[1]};
+  SmallVector<unsigned> warpsPerCTA = getWarpsPerCTA();
+  size_t rank = shapeC.size();
+  SmallVector<unsigned> shapePerCTATile(rank);
+  llvm::transform(
+      llvm::zip_equal(shapeC, warpsPerCTA), shapePerCTATile.begin(),
+      [](auto entry) { return std::get<0>(entry) * std::get<1>(entry); });
+  return shapePerCTATile;
 }
 
 SmallVector<unsigned>
 DpasEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
   size_t rank = shape.size();
-  assert(rank == 2 && "Unexpected rank of mma layout");
+  assert((rank == 2 || rank == 3) && "Unexpected rank of mma layout");
 
-  SmallVector<unsigned> elemsPerThread(rank);
+  SmallVector<unsigned> elemsPerThread(rank, 1);
   auto shapePerCTATile = getShapePerCTATile(shape);
-  unsigned tilesRow = ceil<unsigned>(shape[0], shapePerCTATile[0]);
-  unsigned tilesCol = ceil<unsigned>(shape[1], shapePerCTATile[1]);
+  unsigned tilesRow =
+      ceil<unsigned>(shape[rank - 2], shapePerCTATile[rank - 2]);
+  unsigned tilesCol =
+      ceil<unsigned>(shape[rank - 1], shapePerCTATile[rank - 1]);
   auto sizePerThread = getSizePerThread();
-  elemsPerThread[0] = sizePerThread[0] * tilesRow;
-  elemsPerThread[1] = sizePerThread[1] * tilesCol;
+  if (rank == 3)
+    elemsPerThread[0] =
+        sizePerThread[0] * ceil<unsigned>(shape[0], shapePerCTATile[0]);
+  elemsPerThread[rank - 2] = sizePerThread[rank - 2] * tilesRow;
+  elemsPerThread[rank - 1] = sizePerThread[rank - 1] * tilesCol;
 
   return elemsPerThread;
 }
@@ -157,41 +186,65 @@ unsigned DpasEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
 }
 
 SmallVector<unsigned> DpasEncodingAttr::getCTASplitNum() const {
-  SmallVector<unsigned> res{1, 1};
+  size_t rank = getWarpsPerCTA().size();
+  SmallVector<unsigned> res(rank, 1);
   return res;
 }
 
 SmallVector<unsigned> DpasEncodingAttr::getCTAOrder() const {
-  SmallVector<unsigned> res{1, 0};
+  size_t rank = getWarpsPerCTA().size();
+  auto res = llvm::to_vector(llvm::reverse(llvm::seq<unsigned>(rank)));
  return res;
 }
 
 SmallVector<unsigned> DpasEncodingAttr::getCTAsPerCGA() const {
-  SmallVector<unsigned> res{1, 1};
+  size_t rank = getWarpsPerCTA().size();
+  SmallVector<unsigned> res(rank, 1);
   return res;
 }
 
 SmallVector<int64_t>
 DpasEncodingAttr::getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const {
+  // Always return 3D repetitions for ease of value handling, as the MMA
+  // layout does.
   auto warpsPerCTA = getWarpsPerCTA();
+  int rank = shape.size();
+  SmallVector<int64_t> rep(3, 1);
   if (opIdx == 0) {
     auto shapePerWarp = getShapeA();
-    return {std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0])),
-            std::max<int64_t>(1, shape[1] / shapePerWarp[1])};
-  } else if (opIdx == 1) {
+    int64_t numRepBatch =
+        rank == 3 ? std::max<int64_t>(1, shape[0] /
+                                             (shapePerWarp[0] * warpsPerCTA[0]))
+                  : 1;
+    return {numRepBatch,
+            std::max<int64_t>(1, shape[rank - 2] / (shapePerWarp[rank - 2] *
+                                                    warpsPerCTA[rank - 2])),
+            std::max<int64_t>(1, shape[rank - 1] / shapePerWarp[rank - 1])};
+  }
+
+  if (opIdx == 1) {
     auto shapePerWarp = getShapeB();
-    return {
-        std::max<int64_t>(1, shape[0] / shapePerWarp[0]),
-        std::max<int64_t>(1, shape[1] / (shapePerWarp[1] * warpsPerCTA[1]))};
-  } else {
-    assert(opIdx == 2 && "Unexpected operand id (valid ids are 0, 1 or 2)");
-    auto shapePerWarp = getShapeC();
-    return {
-        std::max<int64_t>(1, mlir::ceil<unsigned>(
-                                 shape[0], shapePerWarp[0] * warpsPerCTA[0])),
-        std::max<int64_t>(1, mlir::ceil<unsigned>(
-                                 shape[1], shapePerWarp[1] * warpsPerCTA[1]))};
+    int64_t numRepBatch =
+        rank == 3 ? std::max<int64_t>(1, shape[0] /
+                                             (shapePerWarp[0] * warpsPerCTA[0]))
+                  : 1;
+    return {numRepBatch,
+            std::max<int64_t>(1, shape[rank - 2] / shapePerWarp[rank - 2]),
+            std::max<int64_t>(1, shape[rank - 1] / (shapePerWarp[rank - 1] *
+                                                    warpsPerCTA[rank - 1]))};
   }
+
+  assert(opIdx == 2 && "Unexpected operand id (valid ids are 0, 1 or 2)");
+  auto shapePerWarp = getShapeC();
+  int64_t numRepBatch =
+      rank == 3
+          ? std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]))
+          : 1;
+  return {numRepBatch,
+          std::max<int64_t>(1, shape[rank - 2] / (shapePerWarp[rank - 2] *
+                                                  warpsPerCTA[rank - 2])),
+          std::max<int64_t>(1, shape[rank - 1] / (shapePerWarp[rank - 1] *
+                                                  warpsPerCTA[rank - 1]))};
 }
 
 unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperands(
@@ -199,23 +252,30 @@ unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperands(
   auto shapePerCTA = getShapePerCTA(*this, shape);
   auto rep = getDPASRepetitions(shapePerCTA, opIdx);
   auto threadsPerWar = getSubGroupSize();
+  size_t rank = shape.size();
   if (opIdx == 0) {
     auto shapeA = getShapeA();
     auto totalElem = product(shapeA);
     // dpas operands scalar are evenly sharded to each work item.
-    return (totalElem / threadsPerWar) * rep[0] * rep[1];
-  } else { // if (opIdx == 1)
+    return (totalElem / threadsPerWar) * product(rep);
+  }
+  if (opIdx == 1) {
     auto shapeB = getShapeB();
     auto totalElem = product(shapeB);
     // dpas operands scalar are evenly sharded to each work item.
-    return (totalElem / threadsPerWar) * rep[0] * rep[1];
+    return (totalElem / threadsPerWar) * product(rep);
   }
+  llvm_unreachable("DpasEncodingAttr opIdx must be 0 or 1");
 }
 
-SmallVector<unsigned> DpasEncodingAttr::getWarpOrder() const { return {1, 0}; }
+SmallVector<unsigned> DpasEncodingAttr::getWarpOrder() const {
+  size_t rank = getWarpsPerCTA().size();
+  return llvm::to_vector(llvm::reverse(llvm::seq<unsigned>(rank)));
+}
 
 SmallVector<unsigned> DpasEncodingAttr::getThreadOrder() const {
-  return {1, 0};
+  size_t rank = getWarpsPerCTA().size();
+  return llvm::to_vector(llvm::reverse(llvm::seq<unsigned>(rank)));
 }
 
 SmallVector<unsigned> DpasEncodingAttr::getWarpsPerCTA() const {
@@ -224,33 +284,51 @@ SmallVector<unsigned> DpasEncodingAttr::getWarpsPerCTA() const {
 }
 
 SmallVector<unsigned> DpasEncodingAttr::getThreadsPerWarp() const {
+  size_t rank = getWarpsPerCTA().size();
+  SmallVector<unsigned> res(rank, 1);
   auto executionSize = getExecutionSize();
   auto subGroupSize = getSubGroupSize();
   if (subGroupSize < executionSize) {
     llvm::report_fatal_error("DpasEncodingAttr sub-group size could not be "
                              "smaller than the execution size");
   }
-  return {subGroupSize / executionSize, executionSize};
+  res[rank - 2] = subGroupSize / executionSize;
+  res[rank - 1] = executionSize;
+  return res;
 }
 
 SmallVector<unsigned>
 DpasEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
                                                    int opIdx) const {
   auto parentShapePerCTATile = getShapePerCTATile(shape);
-  auto threadsPerWarp = getThreadsPerWarp();
+  size_t rank = parentShapePerCTATile.size();
+  assert((rank == 2 || rank == 3) && "unexpected rank number for Dpas layout");
   if (opIdx == 0) {
     auto shapeA = getShapeA();
-    return {parentShapePerCTATile[0], shapeA[1]};
-  } else if (opIdx == 1) {
+    return (rank == 2)
+               ? SmallVector<unsigned>{parentShapePerCTATile[0], shapeA[1]}
+               : SmallVector<unsigned>{parentShapePerCTATile[0],
+                                       parentShapePerCTATile[rank - 2],
+                                       shapeA[rank - 1]};
+  }
+
+  if (opIdx == 1) {
     auto shapeB = getShapeB();
-    return {shapeB[0], parentShapePerCTATile[1]};
-  } else {
-    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
+    return (rank == 2)
+               ? SmallVector<unsigned>{shapeB[0], parentShapePerCTATile[1]}
+               : SmallVector<unsigned>{parentShapePerCTATile[0],
+                                       shapeB[rank - 2],
+                                       parentShapePerCTATile[rank - 1]};
   }
+
+  llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
 }
 
 SmallVector<unsigned>
 DpasEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
+  ArrayRef<unsigned> repCluster = getRepCluster();
+  size_t rank = repCluster.size();
+  assert((rank == 2 || rank == 3) && "unexpected rank number for Dpas layout");
   if (opIdx == 0) {
     SmallVector<unsigned> shapeA = getDPASInstShapeA();
     unsigned subGroupSize = getSubGroupSize();
@@ -267,9 +345,10 @@ DpasEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
                              "be smaller than the threads required per row.");
     }
     unsigned rowsPerWarp = mlir::ceil<unsigned>(subGroupSize, packedColNum);
-    auto repCluster = getRepCluster();
-    return {shapeA[0] / rowsPerWarp * repCluster[0], packedOpsPerLane};
-  } else if (opIdx == 1) {
+    return {shapeA[0] / rowsPerWarp * repCluster[rank - 2], packedOpsPerLane};
+  }
+
+  if (opIdx == 1) {
     auto shapeB = getShapeB();
     auto subGroupSize = getSubGroupSize();
     auto executionSize = getExecutionSize();
@@ -279,13 +358,11 @@ DpasEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
     }
     SmallVector<unsigned> threadsPerWarp = {subGroupSize / executionSize,
                                             executionSize};
-    auto repCluster = getRepCluster();
-    return {shapeB[0] / threadsPerWarp[0],
-            shapeB[1] / threadsPerWarp[1] * repCluster[1]};
-  } else {
-    llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
-    return {};
+    return {shapeB[rank - 2] / threadsPerWarp[0],
+            shapeB[rank - 1] / threadsPerWarp[1] * repCluster[rank - 1]};
   }
+
+  llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
 }
 
 SmallVector<int64_t> DpasEncodingAttr::getElemsPerThreadForOperands(
@@ -293,25 +370,38 @@ SmallVector<int64_t> DpasEncodingAttr::getElemsPerThreadForOperands(
   SmallVector<unsigned> sizePerThread = getSizePerThreadForOperands(opIdx);
   SmallVector<int64_t> repetitions = getDPASRepetitions(shape, opIdx);
 
-  return {static_cast<int64_t>(sizePerThread[0] * repetitions[0]),
-          static_cast<int64_t>(sizePerThread[1] * repetitions[1])};
+  size_t rank = shape.size();
+  SmallVector<int64_t> elemsPerThread(rank);
+  if (rank == 3)
+    elemsPerThread[0] = repetitions[0];
+  elemsPerThread[rank - 2] = sizePerThread[0] * repetitions[1];
+  elemsPerThread[rank - 1] = sizePerThread[1] * repetitions[2];
+
+  return elemsPerThread;
 };
 
 SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() {
+  size_t rank = getWarpsPerCTA().size();
+  assert(rank == 2 || rank == 3);
+  SmallVector<unsigned> contigPerThread(rank, 1);
+
   unsigned threadsPerWarp = getSubGroupSize();
-  auto shapeC = getDPASInstShapeC();
+  auto instShapeC = getDPASInstShapeC();
   // The software vectorization vectorized the value as C array: int a[N] ->
   // int a[N][threadsPerWarp]
-  if (threadsPerWarp > shapeC[1]) {
-    return {1, 1};
-  } else if (threadsPerWarp == shapeC[1]) {
+  if (threadsPerWarp > instShapeC[1]) {
+    return contigPerThread;
+  }
+
+  if (threadsPerWarp == instShapeC[1]) {
     auto repCluster = getRepCluster();
-    return {shapeC[0] * repCluster[0], 1};
-  } else {
-    // threadsPerWarp < shapeC[1]
-    llvm::report_fatal_error("DpasEncodingAttr sub-group size could not "
-                             "be smaller than the threads required per row.");
+    contigPerThread[rank - 2] = instShapeC[0] * repCluster[rank - 2];
+    return contigPerThread;
   }
+
+  // threadsPerWarp < shapeC[1]
+  llvm::report_fatal_error("DpasEncodingAttr sub-group size could not "
+                           "be smaller than the threads required per row.");
 }
 
 LogicalResult DpasEncodingAttr::verify(
@@ -333,8 +423,8 @@ LogicalResult DpasEncodingAttr::verify(
     return emitError() << "systolicDepth must be 8, but was:" << opsPerChan;
   }
 
-  if (repCluster.size() != 2) {
-    return emitError() << "expected rank 2 of repCluster, but the rank is:"
+  if (!(repCluster.size() == 2 || repCluster.size() == 3)) {
+    return emitError() << "expected rank 2 or 3 of repCluster, but the rank is:"
                        << repCluster.size();
   }
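One contract change above is worth keeping in mind while reading the layout-conversion diff below: `getDPASRepetitions` now always returns three values, `{batchRep, repM, repN}`, even for rank-2 shapes. A sketch of that behavior for operand C (assumed semantics; integer division as in the C++):

    def dpas_repetitions_c(shape, shape_per_warp, warps_per_cta):
        # Returns [batch reps, M reps, N reps]; batch collapses to 1 for rank 2.
        rank = len(shape)
        batch = (max(1, shape[0] // (shape_per_warp[0] * warps_per_cta[0]))
                 if rank == 3 else 1)
        rep_m = max(1, shape[rank - 2] //
                    (shape_per_warp[rank - 2] * warps_per_cta[rank - 2]))
        rep_n = max(1, shape[rank - 1] //
                    (shape_per_warp[rank - 1] * warps_per_cta[rank - 1]))
        return [batch, rep_m, rep_n]

    assert dpas_repetitions_c([128, 128], [32, 32], [2, 2]) == [1, 2, 2]
    assert dpas_repetitions_c([4, 128, 128], [1, 32, 32], [2, 2, 1]) == [2, 2, 4]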
diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp
index 90e950bd0c..4ee77e934d 100644
--- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp
+++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp
@@ -275,7 +275,6 @@ LinearLayout ensureLayoutNotSmallerThan(
     return layout;
   }
 
-  MLIRContext *ctx = shape.begin()->first.getContext();
   StringAttr kDim = *layout.getInDimNames().begin();
   assert(kDim == "register" || kDim == "offset" && "unexpected kDim");
 
@@ -491,7 +490,8 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
   assert(dpas && "Must be DPAS layout");
 
   int rank = shape.size();
-  assert(rank == dpas.getWarpsPerCTA().size() && rank == 2 && "Invalid rank");
+  assert(rank == dpas.getWarpsPerCTA().size() && (rank == 2 || rank == 3) &&
+         "Invalid rank");
   MLIRContext *ctx = dpas.getContext();
   SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
 
@@ -508,25 +508,28 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
   int systolicDepth = dpas.getSystolicDepth();
   int repeatCount = dpas.getRepeatCount();
   int executionSize = dpas.getExecutionSize();
-  unsigned KDim = 0;
-  unsigned nonKDim = 0;
+  unsigned KDim, nonKDim;
   if (opIdx == 0) { // Operand A
     auto regBasesA = DPASRegBasesA(opsPerChannel, repeatCount, threadsPerWarp,
                                    systolicDepth);
     auto laneBasesA =
         DPASLaneBasesA(opsPerChannel, threadsPerWarp, systolicDepth);
     tileLayout = LinearLayout({{kRegister, regBasesA}, {kLane, laneBasesA}},
-                              outDimNames);
-    // A only repeats by repCluster[0]
-    tileLayout *=
-        LinearLayout::identity1D(repCluster[0], kRegister, outDimNames[0]);
+                              ArrayRef(outDimNames).take_back(2));
+    // A only repeats by repCluster[rank - 2]
+    nonKDim = rank - 2;
+    KDim = rank - 1;
+    tileLayout *= LinearLayout::identity1D(repCluster[nonKDim], kRegister,
+                                           outDimNames[nonKDim]);
 
-    nonKDim = 0;
-    KDim = 1;
     // K-dimension is shared among warps
-    tileLayout *= LinearLayout::zeros1D(warpsPerCTA[1], kWarp, outDimNames[1]);
     tileLayout *=
-        LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[0]);
+        LinearLayout::zeros1D(warpsPerCTA[KDim], kWarp, outDimNames[KDim]);
+    tileLayout *= LinearLayout::identity1D(warpsPerCTA[nonKDim], kWarp,
+                                           outDimNames[nonKDim]);
+    if (rank == 3)
+      tileLayout *=
+          LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[0]);
   } else if (opIdx == 1) { // Operand B
     auto regBasesB = DPASRegBasesB(opsPerChannel, executionSize, threadsPerWarp,
@@ -534,49 +537,62 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
     auto laneBasesB =
         DPASLaneBasesB(opsPerChannel, threadsPerWarp, executionSize);
     tileLayout = LinearLayout({{kRegister, regBasesB}, {kLane, laneBasesB}},
-                              outDimNames);
-    // B only repeats by repCluster[1]
-    tileLayout *=
-        LinearLayout::identity1D(repCluster[1], kRegister, outDimNames[1]);
-
-    nonKDim = 1;
-    KDim = 0;
+                              ArrayRef(outDimNames).take_back(2));
+    // B only repeats by repCluster[rank - 1]
+    nonKDim = rank - 1;
+    KDim = rank - 2;
+    tileLayout *= LinearLayout::identity1D(repCluster[nonKDim], kRegister,
+                                           outDimNames[nonKDim]);
 
     // K-dimension is shared among warps
+    tileLayout *= LinearLayout::identity1D(warpsPerCTA[nonKDim], kWarp,
+                                           outDimNames[nonKDim]);
     tileLayout *=
-        LinearLayout::identity1D(warpsPerCTA[1], kWarp, outDimNames[1]);
-    tileLayout *= LinearLayout::zeros1D(warpsPerCTA[0], kWarp, outDimNames[0]);
+        LinearLayout::zeros1D(warpsPerCTA[KDim], kWarp, outDimNames[KDim]);
+    if (rank == 3)
+      tileLayout *=
+          LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[0]);
   } else { // opIdx=2 -> Operand C
     auto regBasesC = DPASRegBasesC(repeatCount, executionSize, threadsPerWarp);
    auto laneBasesC =
        DPASLaneBasesC(repeatCount, executionSize, threadsPerWarp);
     tileLayout = LinearLayout({{kRegister, regBasesC}, {kLane, laneBasesC}},
-                              outDimNames);
+                              ArrayRef(outDimNames).take_back(2));
     // The per-inst layout is repeated at each repCluster.
     // Hence, multiply with the identity layouts starting from the
     // least significant dimension.
+    nonKDim = rank - 2;
+    KDim = rank - 1;
+    tileLayout *= LinearLayout::identity1D(repCluster[KDim], kRegister,
+                                           outDimNames[KDim]);
+    tileLayout *= LinearLayout::identity1D(repCluster[nonKDim], kRegister,
+                                           outDimNames[nonKDim]);
+
+    // The identical layout is repeated among warps
     tileLayout *=
-        LinearLayout::identity1D(repCluster[1], kRegister, outDimNames[1]);
-    tileLayout *=
-        LinearLayout::identity1D(repCluster[0], kRegister, outDimNames[0]);
-
-    // The identical layout is repeated among warps
-    tileLayout *=
-        LinearLayout::identity1D(warpsPerCTA[1], kWarp, outDimNames[1]);
-    tileLayout *=
-        LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[0]);
-    nonKDim = 0;
-    KDim = 1;
+        LinearLayout::identity1D(warpsPerCTA[KDim], kWarp, outDimNames[KDim]);
+    tileLayout *= LinearLayout::identity1D(warpsPerCTA[nonKDim], kWarp,
+                                           outDimNames[nonKDim]);
+    if (rank == 3)
+      tileLayout *=
+          LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[0]);
   }
 
   // Lastly, the layout repeats to match the shape.
   // Operand A/B repeats through the K-dimension first then repeats
   // through the non-K dimension.
   SmallVector<int64_t> numReps = dpas.getDPASRepetitions(shape, opIdx);
+
+  // numReps is always 3D; shift the dim index by 1 when rank is 2.
+  int repDimK = rank == 2 ? KDim + 1 : KDim;
+  int repDimNonK = rank == 2 ? nonKDim + 1 : nonKDim;
   tileLayout *=
-      LinearLayout::identity1D(numReps[KDim], kRegister, outDimNames[KDim]);
-  tileLayout *= LinearLayout::identity1D(numReps[nonKDim], kRegister,
+      LinearLayout::identity1D(numReps[repDimK], kRegister, outDimNames[KDim]);
+  tileLayout *= LinearLayout::identity1D(numReps[repDimNonK], kRegister,
                                          outDimNames[nonKDim]);
+  if (rank == 3)
+    tileLayout *=
+        LinearLayout::identity1D(numReps[0], kRegister, outDimNames[0]);
 
   return combineCtaCgaWithShape(std::move(tileLayout),
                                 CTALayoutAttr::getDefault(ctx, rank), shape);
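Because `numReps` is always 3D while `KDim`/`nonKDim` index the tensor rank, the rank-2 case above has to shift the repetition index by one. The mapping, as a tiny hypothetical helper mirroring the `repDimK`/`repDimNonK` computation:

    def rep_index(dim, rank):
        # Map a tensor dim to an index into the always-3D numReps vector.
        return dim + 1 if rank == 2 else dim

    # rank 2: K dim (tensor dim 1) -> numReps[2]; non-K dim 0 -> numReps[1]
    assert rep_index(1, 2) == 2 and rep_index(0, 2) == 1
    # rank 3: tensor dims already line up with [batch, M, N]
    assert rep_index(2, 3) == 2 and rep_index(1, 3) == 1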
+  int repDimK = rank == 2 ? KDim + 1 : KDim;
+  int repDimNonK = rank == 2 ? nonKDim + 1 : nonKDim;
   tileLayout *=
-      LinearLayout::identity1D(numReps[KDim], kRegister, outDimNames[KDim]);
-  tileLayout *= LinearLayout::identity1D(numReps[nonKDim], kRegister,
+      LinearLayout::identity1D(numReps[repDimK], kRegister, outDimNames[KDim]);
+  tileLayout *= LinearLayout::identity1D(numReps[repDimNonK], kRegister,
                                          outDimNames[nonKDim]);
+  if (rank == 3)
+    tileLayout *=
+        LinearLayout::identity1D(numReps[0], kRegister, outDimNames[0]);

   return combineCtaCgaWithShape(std::move(tileLayout),
                                 CTALayoutAttr::getDefault(ctx, rank), shape);
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
index f5a4bf0bad..d310f82e3c 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp
@@ -112,17 +112,24 @@ struct ConvertLayoutOpConversion
       return multiDimOffset;
     }
     if (auto dpasLayout = dyn_cast(layout)) {
-      assert(rank == 2);
+      assert((rank == 2 || rank == 3) &&
+             "unexpected rank number for Dpas layout");
       auto multiDimBase = ::intel::emitBaseIndexForLayout(
           loc, rewriter, targetInfo, layout, type, false);
       SmallVector> offsets;
       ::emitOffsetForDpasLayoutPerCTA(
-          dpasLayout, offsets, multiDimCTAInRepId[0] * shapePerCTATile[0],
-          multiDimCTAInRepId[1] * shapePerCTATile[1]);
+          dpasLayout, offsets,
+          multiDimCTAInRepId[rank - 2] * shapePerCTATile[rank - 2],
+          multiDimCTAInRepId[rank - 1] * shapePerCTATile[rank - 1]);

-      SmallVector multiDimOffset = {
-          add(multiDimBase[0], i32_val(offsets[elemId][0])),
-          add(multiDimBase[1], i32_val(offsets[elemId][1]))};
+      SmallVector multiDimOffset(rank);
+      if (rank == 3)
+        multiDimOffset[0] = add(multiDimBase[0], i32_val(multiDimCTAInRepId[0] *
+                                                         shapePerCTATile[0]));
+      multiDimOffset[rank - 2] =
+          add(multiDimBase[rank - 2], i32_val(offsets[elemId][rank - 2]));
+      multiDimOffset[rank - 1] =
+          add(multiDimBase[rank - 1], i32_val(offsets[elemId][rank - 1]));
       return multiDimOffset;
     }

@@ -315,7 +322,7 @@ struct ConvertLayoutOpConversion
     return success();
   }

-  using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
+  using ValueTable = std::map<std::array<unsigned, 3>, Value>;
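+  // Key: the {batch, outer, inner} replica coordinate of one packed value.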

   ValueTable getValuesFromDpasLayoutStruct(Location loc,
                                            ConversionPatternRewriter &rewriter,
@@ -334,20 +341,26 @@ struct ConvertLayoutOpConversion
     SmallVector repetitions =
         dpasLayout.getDPASRepetitions(srcType.getShape(), 2 /*operand C*/);
     ArrayRef repCluster = dpasLayout.getRepCluster();
+    size_t rank = repCluster.size();
+    size_t outerDim = rank - 2;
+    size_t innerDim = rank - 1;

     int offset = 0;
     ValueTable result;
-    for (int i = 0; i < repetitions[0]; ++i) {
-      for (int j = 0; j < repetitions[1]; ++j) {
-        for (int repOuter = 0; repOuter < repCluster[0]; ++repOuter) {
-          for (int repInner = 0; repInner < repCluster[1]; ++repInner) {
-            Value matVal = rewriter.create(loc, dotOpTy);
-            for (int k = 0; k < numElemsPerOperand; ++k) {
-              matVal =
-                  insert_element(dotOpTy, matVal, elems[offset++], i32_val(k));
+    for (unsigned b = 0; b < repetitions[0]; ++b) {
+      for (int i = 0; i < repetitions[1]; ++i) {
+        for (int j = 0; j < repetitions[2]; ++j) {
+          for (int repOuter = 0; repOuter < repCluster[outerDim]; ++repOuter) {
+            for (int repInner = 0; repInner < repCluster[innerDim];
+                 ++repInner) {
+              Value matVal = rewriter.create(loc, dotOpTy);
+              for (int k = 0; k < numElemsPerOperand; ++k) {
+                matVal = insert_element(dotOpTy, matVal, elems[offset++],
+                                        i32_val(k));
+              }
+              result[{b, i * repCluster[outerDim] + repOuter,
+                      j * repCluster[innerDim] + repInner}] = matVal;
             }
-            result[{i * repCluster[0] + repOuter,
-                    j * repCluster[1] + repInner}] = matVal;
           }
         }
       }
     }
@@ -365,35 +378,39 @@ struct ConvertLayoutOpConversion
     SmallVector repetitions =
         dpasLayout.getDPASRepetitions(dstType.getShape(), opIdx);
     ArrayRef repCluster = dpasLayout.getRepCluster();
+    size_t rank = repCluster.size();
+    unsigned repBatch = repetitions[0];
     unsigned repOuter = 0u;
     unsigned repInner = 0u;
     unsigned repClusterOuter = 0u;
     if (opIdx == 0) {
       // operand A
-      repOuter = repetitions[0];
-      repInner = repetitions[1];
-      repClusterOuter = repCluster[0];
+      repOuter = repetitions[1];
+      repInner = repetitions[2];
+      repClusterOuter = repCluster[rank - 2];
     } else {
       // operand B
-      repOuter = repetitions[1];
-      repInner = repetitions[0];
-      repClusterOuter = repCluster[1];
+      repOuter = repetitions[2];
+      repInner = repetitions[1];
+      repClusterOuter = repCluster[rank - 1];
     }

     // TODO: Operands B requires extra steps to combine [8, 16] to [16, 16].
     SmallVector elems;
-    for (int m = 0; m < repOuter; ++m) {
-      for (int k = 0; k < repInner; ++k) {
-        for (int repOuterIdx = 0; repOuterIdx < repClusterOuter;
-             ++repOuterIdx) {
-          unsigned offsetM = m * repClusterOuter + repOuterIdx;
-          unsigned offsetN = k;
-          Value matVal = vals.at({offsetM, offsetN});
-          VectorType vecType = cast(matVal.getType());
-          Type valTy = vecType.getElementType();
-          for (int i = 0; i < vecType.getNumElements(); ++i) {
-            Value val = extract_element(valTy, matVal, i32_val(i));
-            elems.push_back(val);
+    for (unsigned b = 0; b < repBatch; ++b) {
+      for (int m = 0; m < repOuter; ++m) {
+        for (int k = 0; k < repInner; ++k) {
+          for (int repOuterIdx = 0; repOuterIdx < repClusterOuter;
+               ++repOuterIdx) {
+            unsigned offsetM = m * repClusterOuter + repOuterIdx;
+            unsigned offsetN = k;
+            Value matVal = vals.at({b, offsetM, offsetN});
+            VectorType vecType = cast(matVal.getType());
+            Type valTy = vecType.getElementType();
+            for (int i = 0; i < vecType.getNumElements(); ++i) {
+              Value val = extract_element(valTy, matVal, i32_val(i));
+              elems.push_back(val);
+            }
           }
         }
       }
     }
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandDPAS.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandDPAS.cpp
index da4480206a..49d88dca1f 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandDPAS.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandDPAS.cpp
@@ -3,7 +3,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "llvm/Support/ErrorHandling.h"

-using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
+using ValueTable = std::map<std::array<unsigned, 3>, Value>;
 using mlir::triton::gpu::getShapePerCTA;
 using mlir::triton::gpu::SharedEncodingAttr;
 using mlir::triton::gpu::intel::DpasEncodingAttr;

@@ -16,17 +16,22 @@ template <unsigned opIdx> class DpasMatmulLoader {
   DpasMatmulLoader(DpasEncodingAttr dpasLayout, MemDescType descTy,
                    unsigned warpsPerTile, ArrayRef smemStrides,
                    const SmallVector &warpShape,
+                   SmallVector multiDimWarpId,
                    ConversionPatternRewriter &rewriter,
                    const LLVMTypeConverter *typeConverter, Location loc)
       : dpasLayout(dpasLayout), descTy(descTy), smemStrides(smemStrides),
-        rewriter(rewriter), loc(loc) {
+        multiDimWarpId(multiDimWarpId), rewriter(rewriter), loc(loc) {
     static_assert(opIdx == 0 || opIdx == 1);

-    unsigned kDim = (opIdx == 0) ? 1 : 0;
+    size_t rank = warpShape.size();
+    unsigned kDim = opIdx ? rank - 2 : rank - 1;
+    unsigned nonKDim = opIdx ? rank - 1 : rank - 2;
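+    // kDim/nonKDim address the last two (row/column) dims; the batch dim,
+    // when present, is handled through repBatchDimStride below.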
+    // Assume that smem is created with layout order {2, 1, 0}
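+    // Under that order smemStrides[0] is the batch stride: one full 2D
+    // operand tile per batch index.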
+    repBatchDimStride = smemStrides[0];

     repKDimStride = mul(i32_val(warpShape[kDim]), smemStrides[kDim]);
     repNonKDimStride =
-        mul(i32_val(warpShape[kDim ^ 1] * warpsPerTile), smemStrides[kDim ^ 1]);
-    warpMatStride = mul(i32_val(warpShape[kDim ^ 1]), smemStrides[kDim ^ 1]);
+        mul(i32_val(warpShape[nonKDim] * warpsPerTile), smemStrides[nonKDim]);
+    warpMatStride = mul(i32_val(warpShape[nonKDim]), smemStrides[nonKDim]);

     unsigned threadsPerWarp = getThreadsPerWarp();
@@ -44,9 +49,9 @@ template <unsigned opIdx> class DpasMatmulLoader {
   SmallVector computeLdsMatOffs(Value warpOff, Value lane,
                                 Value cSwizzleOffset);

   // Load the matrix value.
-  Value loadMatrix(int repOuter, int repInner, const ArrayRef ptrs,
-                   LLVM::LLVMStructType structTy, Type smemTy,
-                   Value cSwizzleOffset) const;
+  Value loadMatrix(int repBatch, int repOuter, int repInner,
+                   const ArrayRef ptrs, LLVM::LLVMStructType structTy,
+                   Type smemTy, Value cSwizzleOffset) const;

 private:
   unsigned getThreadsPerWarp() const {
@@ -57,6 +62,8 @@ template <unsigned opIdx> class DpasMatmulLoader {
   MemDescType descTy;
   SmallVector smemStrides;
+  SmallVector multiDimWarpId;
+  Value repBatchDimStride;
   Value repNonKDimStride;
   Value repKDimStride;
@@ -77,6 +84,7 @@ DpasMatmulLoader<opIdx>::computeLdsMatOffs(Value warpId, Value laneId,
   unsigned executionSize = dpasLayout.getExecutionSize();
   unsigned opsPerChannel = dpasLayout.getOpsPerChannel();
   unsigned threadsPerWarp = getThreadsPerWarp();
+  unsigned rank = dpasLayout.getRepCluster().size();

   Value laneRowIndex, laneColIndex;
   unsigned rowsPerInst = 0u, rowsPerWarp = 0u, packedOpsPerLane = 0u;
@@ -87,7 +95,7 @@ DpasMatmulLoader<opIdx>::computeLdsMatOffs(Value warpId, Value laneId,
     SmallVector shapeA = dpasLayout.getShapeA();
     // Unlike the operand B, to pack the value to i16 for scalar bit width <=16.
     packedOpsPerLane = opsPerChannel == 4 ? 2 : 1;
-    unsigned packedColNum = shapeA[1] / packedOpsPerLane;
+    unsigned packedColNum = shapeA[rank - 1] / packedOpsPerLane;
     assert(threadsPerWarp >= packedColNum &&
            "DpasEncodingAttr sub-group size could not "
            "be smaller than the threads required per row for A operand.");
@@ -128,9 +136,11 @@ DpasMatmulLoader<opIdx>::computeLdsMatOffs(Value warpId, Value laneId,
   SmallVector instShape = opIdx == 0 ? dpasLayout.getDPASInstShapeA()
                                      : dpasLayout.getDPASInstShapeB();
   ArrayRef shareMemoryShape = descTy.getShape();
+  SmallVector shapePerCTA = getShapePerCTA(descTy);

   SmallVector offs(numPtrs);
-  const unsigned repClusterSize = dpasLayout.getRepCluster()[opIdx];
+  const unsigned repClusterSize =
+      dpasLayout.getRepCluster()[opIdx ? rank - 1 : rank - 2];
   unsigned index = 0u;
   for (unsigned repIdx = 0; repIdx < repClusterSize; ++repIdx) {
     unsigned repIndex = repIdx * instShape[opIdx];
@@ -154,8 +164,8 @@ DpasMatmulLoader<opIdx>::computeLdsMatOffs(Value warpId, Value laneId,
       // round the offset into the tensor's shape limitation. (Rounded
       // broadcast)
-      iBase = urem(iBase, i32_val(shareMemoryShape[0]));
-      jBase = urem(jBase, i32_val(shareMemoryShape[1]));
+      iBase = urem(iBase, i32_val(shareMemoryShape[rank - 2]));
+      jBase = urem(jBase, i32_val(shareMemoryShape[rank - 1]));

       // inner index offset
       Value jOff = zeroVal;
@@ -164,10 +174,17 @@ DpasMatmulLoader<opIdx>::computeLdsMatOffs(Value warpId, Value laneId,
       jOff = add(jOff, udiv(cSwizzleOffset, vecVal));
       jOff = mul(xor_(jOff, phase), vecVal);

-      Value i = add(mul(iBase, smemStrides[0]), iOff);
-      Value j = add(mul(jBase, smemStrides[1]), jOff);
+      Value i = add(mul(iBase, smemStrides[rank - 2]), iOff);
+      Value j = add(mul(jBase, smemStrides[rank - 1]), jOff);

-      offs[index++] = add(i, j);
+      Value baseOff;
+      if (shapePerCTA.size() == 3 && shapePerCTA[0] > 1) {
+        Value batchOffset =
+            mul(multiDimWarpId[0], i32_val(shapePerCTA[1] * shapePerCTA[2]));
+        offs[index++] = add(batchOffset, add(i, j));
+      } else {
+        offs[index++] = add(i, j);
+      }
     }
   }

@@ -176,11 +193,9 @@ template
-Value DpasMatmulLoader<opIdx>::loadMatrix(int repOuter, int repInner,
-                                          const ArrayRef ptrs,
-                                          LLVM::LLVMStructType structTy,
-                                          Type smemTy,
-                                          Value cSwizzleOffset) const {
+Value DpasMatmulLoader<opIdx>::loadMatrix(
+    int repBatch, int repOuter, int repInner, const ArrayRef ptrs,
+    LLVM::LLVMStructType structTy, Type smemTy, Value cSwizzleOffset) const {
   Type elemTy = structTy.getBody()[0];
   assert(
       llvm::any_of(structTy.getBody(), [&](Type ty) { return ty == elemTy; }) &&
@@ -189,6 +204,12 @@ Value DpasMatmulLoader<opIdx>::loadMatrix(int repOuter, int repInner,
   Value offsetOuter = mul(i32_val(repOuter), repNonKDimStride);
   Value offsetInner = mul(i32_val(repInner), repKDimStride);
   Value offset = add(offsetOuter, offsetInner);
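+  // Each batch replica is warpsPerCTA[0] warps away along the batch dim,
+  // hence the repBatch * warpsPerCTA[0] scaling below.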
+  if (repBatch > 0) {
+    SmallVector warpsPerCTA = dpasLayout.getWarpsPerCTA();
+    Value offsetBatch =
+        mul(i32_val(repBatch * warpsPerCTA[0]), repBatchDimStride);
+    offset = add(offset, offsetBatch);
+  }

   Value llvmStruct = rewriter.create(loc, structTy);
   size_t elemNum = structTy.getBody().size();
@@ -203,18 +224,20 @@ Value DpasMatmulLoader<opIdx>::loadMatrix(int repOuter, int repInner,
 }

 Value composeValuesToDotOperandLayoutStruct(
-    const ValueTable &vals, int n0, int n1,
+    const ValueTable &vals, int batch, int n0, int n1,
     const LLVMTypeConverter *typeConverter, Location loc,
     ConversionPatternRewriter &rewriter) {
   std::vector elems;
-  for (int m = 0; m < n0; ++m) {
-    for (int k = 0; k < n1; ++k) {
-      Value matVal = vals.at({m, k});
-      auto matType = cast(matVal.getType());
-      Type valTy = matType.getBody()[0];
-      for (int i = 0; i < matType.getBody().size(); ++i) {
-        auto val = extract_val(valTy, matVal, i);
-        elems.push_back(val);
+  for (int b = 0; b < batch; ++b) {
+    for (int m = 0; m < n0; ++m) {
+      for (int k = 0; k < n1; ++k) {
+        Value matVal = vals.at({b, m, k});
+        auto matType = cast(matVal.getType());
+        Type valTy = matType.getBody()[0];
+        for (int i = 0; i < matType.getBody().size(); ++i) {
+          auto val = extract_val(valTy, matVal, i);
+          elems.push_back(val);
+        }
       }
     }
   }
@@ -245,10 +268,11 @@ Type getSharedMemTy(Type argType) {
 }

 template <unsigned opIdx>
-std::function<void(int, int)>
+std::function<void(int, int, int)>
 getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj,
                 DpasEncodingAttr dpasLayout, unsigned warpsPerTile,
-                SmallVector instrShape, Value warpId,
+                SmallVector shapePerWarp,
+                SmallVector multiDimWarpId, Value warpId,
                 Value outerWarpDim, Value laneId, ValueTable &vals,
                 const LLVMTypeConverter *typeConverter,
                 ConversionPatternRewriter &rewriter, Location loc) {
@@ -259,12 +283,14 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj,
   auto sharedLayout = cast(descTy.getEncoding());
   ArrayRef order = sharedLayout.getOrder();
+  size_t rank = order.size();

   // (a, b) is the coordinate.
-  auto load = [=, &rewriter, &smemObj, &instrShape, &vals](int a, int b) {
-    DpasMatmulLoader<opIdx> loader(dpasLayout, descTy, warpsPerTile,
-                                   smemObj.strides, instrShape, rewriter,
-                                   typeConverter, loc);
+  auto load = [=, &rewriter, &smemObj, &shapePerWarp, &multiDimWarpId,
+               &vals](int batch, int outer, int inner) {
+    DpasMatmulLoader<opIdx> loader(
+        dpasLayout, descTy, warpsPerTile, smemObj.strides, shapePerWarp,
+        multiDimWarpId, rewriter, typeConverter, loc);

     // Offset of a slice within the original tensor in shared memory.
     Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
@@ -282,14 +308,15 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj,
         gep(ptr_ty(rewriter.getContext(), 3), smemTy, smemBase, offs[i]);

     // Load from shared memory.
-    unsigned totalElem = product(instrShape);
+    unsigned totalElem = product(shapePerWarp);
     unsigned threadsPerWarp = product(getThreadsPerWarp(dpasLayout));
     auto matTy = LLVM::LLVMStructType::getLiteral(
         eltTy.getContext(),
         SmallVector(totalElem / threadsPerWarp,
                     typeConverter->convertType(eltTy)));

-    vals[{a, b}] = loader.loadMatrix(a, b, ptrs, matTy, smemTy, cSwizzleOffset);
+    vals[{batch, outer, inner}] = loader.loadMatrix(
+        batch, outer, inner, ptrs, matTy, smemTy, cSwizzleOffset);
   };

   return load;
@@ -324,27 +351,33 @@ Value loadOperand(ConversionPatternRewriter &rewriter, Location loc,
   SmallVector multiDimWarpId =
       LLVM::delinearize(rewriter, loc, warpId, warpsPerCTA, order);

-  unsigned ceilRes = mlir::ceil(shapePerCTA[opIdx], shape[opIdx]);
-  Value outerWarpDim = urem(multiDimWarpId[opIdx], i32_val(ceilRes));
-  unsigned warpsPerTile = std::min(warpsPerCTA[opIdx], ceilRes);
+  unsigned rank = shape.size();
+  unsigned dimOuter = opIdx ? (rank - 1) : (rank - 2);
+  unsigned ceilRes =
+      mlir::ceil(shapePerCTA[dimOuter], shape[dimOuter]);
+  Value outerWarpDim = urem(multiDimWarpId[dimOuter], i32_val(ceilRes));
+  unsigned warpsPerTile = std::min(warpsPerCTA[dimOuter], ceilRes);

   // Get the function to use to load the operand.
   ValueTable vals;
-  std::function<void(int, int)> loadFn = getLoadMatrixFn(
-      descTy, smemObj, dpasLayout, warpsPerTile, std::move(shape), warpId,
-      outerWarpDim, laneId, vals, typeConverter, rewriter, loc);
+  std::function<void(int, int, int)> loadFn = getLoadMatrixFn(
+      descTy, smemObj, dpasLayout, warpsPerTile, std::move(shape),
+      std::move(multiDimWarpId), warpId, outerWarpDim, laneId, vals,
+      typeConverter, rewriter, loc);

   // Load the operand.
-  int64_t numRepOuter = numReps[opIdx];
-  int64_t numRepK = numReps[(opIdx == 0) ? 1 : 0];
+  int64_t numRepBatch = numReps[0];
+  int64_t numRepOuter = numReps[opIdx ? 2 : 1];
+  int64_t numRepK = numReps[opIdx ? 1 : 2];
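+  // numReps is {batch, M, K} for operand A and {batch, K, N} for operand B,
+  // so the outer and K indices swap between the two operands.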

-  for (int m = 0; m < numRepOuter; ++m)
-    for (int k = 0; k < numRepK; ++k)
-      loadFn(m, k);
+  for (int b = 0; b < numRepBatch; ++b)
+    for (int m = 0; m < numRepOuter; ++m)
+      for (int k = 0; k < numRepK; ++k)
+        loadFn(b, m, k);

   // Format the values into an LLVM::Struct.
-  return composeValuesToDotOperandLayoutStruct(vals, numRepOuter, numRepK,
-                                               typeConverter, loc, rewriter);
+  return composeValuesToDotOperandLayoutStruct(
+      vals, numRepBatch, numRepOuter, numRepK, typeConverter, loc, rewriter);
 }

 } // namespace
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp
index 93ce802a98..529e7f07a6 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp
@@ -17,7 +17,7 @@ namespace {

 class DotOpDPASConversionHelper {
 public:
-  using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
+  using ValueTable = std::map<std::array<unsigned, 3>, Value>;

   DotOpDPASConversionHelper(DpasEncodingAttr dpasLayout,
                             ConversionPatternRewriter &rewriter,
@@ -136,9 +136,10 @@ class DotOpDPASConversionHelper {
         ATensorTy.getShape(), AEncoding.getOpIdx());
     SmallVector repB = BDpasEncoding.getDPASRepetitions(
         BTensorTy.getShape(), BEncoding.getOpIdx());
-    assert(repA[1] == repB[0] && "Unexpected rep for A and B operands");
-
-    unsigned repM = repA[0], repN = repB[1], repK = repA[1];
+    assert(repA[0] == repB[0] && "A and B should have the same batch size");
+    assert(repA[2] == repB[1] && "Unexpected rep for A and B operands");
+    unsigned repBatch = repA[0];
+    unsigned repM = repA[1], repN = repB[2], repK = repA[2];

     auto dpasType = DPASAnalysis::getDPASType(op);
     auto dpasEncoding = cast(DTensorTy.getEncoding());
@@ -146,13 +147,13 @@ class DotOpDPASConversionHelper {
     std::tie(dTy, cTy, aTy, bTy) =
         getDPASOperandsType(dpasType, op.getContext(), dpasEncoding);
     ValueTable ha = getValuesFromDotOperandLayoutStruct(
-        loadedA, repM, repK,
+        loadedA, repBatch, repM, repK,
         typeConverter->convertType(ATensorTy.getElementType()), aTy, 0);
     ValueTable hb = getValuesFromDotOperandLayoutStruct(
-        loadedB, repN, repK,
+        loadedB, repBatch, repN, repK,
         typeConverter->convertType(BTensorTy.getElementType()), bTy, 1);
     ValueTable fc = getValuesFromDotOperandLayoutStruct(
-        loadedC, repM, repN,
+        loadedC, repBatch, repM, repN,
         typeConverter->convertType(CTensorTy.getElementType()), cTy, 2);

     Type resElemTy = DTensorTy.getElementType();
@@ -166,16 +167,17 @@ class DotOpDPASConversionHelper {
            "A and B precision enumerators do not match");

     LLVM_DEBUG({
+      llvm::dbgs() << "repBatch = " << repBatch << "\n";
       llvm::dbgs() << "repM = " << repM << "\n";
       llvm::dbgs() << "repK = " << repK << "\n";
       llvm::dbgs() << "repN = " << repN << "\n";
       llvm::dbgs() << "fc.size()= " << fc.size() << "\n";
     });

-    auto generateDPASOp = [&](unsigned m, unsigned n, unsigned k) {
-      Value valA = ha.at({m, k});
-      Value valB = hb.at({n, k});
-      Value valc = fc.at({m, n});
+    auto generateDPASOp = [&](unsigned b, unsigned m, unsigned n, unsigned k) {
+      Value valA = ha.at({b, m, k});
+      Value valB = hb.at({b, n, k});
+      Value valc = fc.at({b, m, n});

       TritonGEN::PrecisionTypeAttr pA =
           TritonGEN::PrecisionTypeAttr::get(A.getContext(), APrecision);
@@ -183,21 +185,23 @@ class DotOpDPASConversionHelper {
           TritonGEN::PrecisionTypeAttr::get(B.getContext(), BPrecision);
       auto RC = IntegerAttr::get(rewriter.getIntegerType(32),
                                  dpasEncoding.getRepeatCount());
-      fc.at({m, n}) = rewriter.create(
+      fc.at({b, m, n}) = rewriter.create(
           loc, dTy, valc, valA, valB, pA, pB, RC);
     };
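+    // fc is updated in place, so iterating k in the loop nest below forms
+    // a serial accumulation chain per (b, m, n) tile.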

     ArrayRef repCluster = dpasEncoding.getRepCluster();
-    for (int k = 0; k < repK; ++k)
-      for (int m = 0; m < repM; ++m)
-        for (int n = 0; n < repN; ++n)
-          for (int repRow = 0; repRow < repCluster[0]; ++repRow)
-            for (int repCol = 0; repCol < repCluster[1]; ++repCol)
-              generateDPASOp(m * repCluster[0] + repRow,
-                             n * repCluster[1] + repCol, k);
-
-    Value res =
-        composeValuesToDotOperandLayoutStruct(fc, repM, repN, resElemTy);
+    unsigned rank = repCluster.size();
+    for (int b = 0; b < repBatch; ++b)
+      for (int k = 0; k < repK; ++k)
+        for (int m = 0; m < repM; ++m)
+          for (int n = 0; n < repN; ++n)
+            for (int repRow = 0; repRow < repCluster[rank - 2]; ++repRow)
+              for (int repCol = 0; repCol < repCluster[rank - 1]; ++repCol)
+                generateDPASOp(b, m * repCluster[rank - 2] + repRow,
+                               n * repCluster[rank - 1] + repCol, k);
+
+    Value res = composeValuesToDotOperandLayoutStruct(fc, repBatch, repM, repN,
+                                                      resElemTy);

     rewriter.replaceOp(op, res);
@@ -229,21 +233,25 @@ class DotOpDPASConversionHelper {
   }

   Value composeValuesToDotOperandLayoutStruct(const ValueTable &vals,
-                                              int64_t dim0, int64_t dim1,
+                                              int64_t dimBatch,
+                                              int64_t dimRow, int64_t dimCol,
                                               Type elemTy) const {
     ArrayRef repCluster = dpasLayout.getRepCluster();
+    size_t rank = repCluster.size();

     std::vector elems;
-    for (int m = 0; m < dim0; ++m) {
-      for (int k = 0; k < dim1; ++k) {
-        for (int repRow = 0; repRow < repCluster[0]; ++repRow) {
-          for (int repCol = 0; repCol < repCluster[1]; ++repCol) {
-            Value matVal = vals.at(
-                {m * repCluster[0] + repRow, k * repCluster[1] + repCol});
-            VectorType vecType = cast(matVal.getType());
-            Type valTy = vecType.getElementType();
-            for (int i = 0; i < vecType.getNumElements(); ++i) {
-              Value val = extract_element(valTy, matVal, i32_val(i));
-              elems.push_back(val);
+    for (unsigned b = 0; b < dimBatch; ++b) {
+      for (int m = 0; m < dimRow; ++m) {
+        for (int k = 0; k < dimCol; ++k) {
+          for (int repRow = 0; repRow < repCluster[rank - 2]; ++repRow) {
+            for (int repCol = 0; repCol < repCluster[rank - 1]; ++repCol) {
+              Value matVal = vals.at({b, m * repCluster[rank - 2] + repRow,
+                                      k * repCluster[rank - 1] + repCol});
+              VectorType vecType = cast(matVal.getType());
+              Type valTy = vecType.getElementType();
+              for (int i = 0; i < vecType.getNumElements(); ++i) {
+                Value val = extract_element(valTy, matVal, i32_val(i));
+                elems.push_back(val);
+              }
            }
          }
        }
@@ -258,29 +266,31 @@ class DotOpDPASConversionHelper {
     return packLLElements(loc, typeConverter, elems, rewriter, structTy);
   }

-  ValueTable getValuesFromDotOperandLayoutStruct(Value val, int64_t outer,
-                                                 int64_t inner, Type elemTy,
+  ValueTable getValuesFromDotOperandLayoutStruct(Value val, int64_t batch,
+                                                 int64_t outer, int64_t inner,
+                                                 Type elemTy,
                                                  Type dotOperandType,
                                                  uint32_t opIdx) const {
     SmallVector elems = unpackLLElements(loc, val, rewriter);

     ArrayRef repCluster = dpasLayout.getRepCluster();
+    size_t rank = repCluster.size();
     unsigned repClusterOuter = 0u;
     unsigned repClusterInner = 0u;
     switch (opIdx) {
     case 0:
       // operand A
-      repClusterOuter = repCluster[0];
+      repClusterOuter = repCluster[rank - 2];
       repClusterInner = 1;
       break;
     case 1:
       // operand B
       repClusterInner = 1;
-      repClusterOuter = repCluster[1];
+      repClusterOuter = repCluster[rank - 1];
       break;
     case 2:
       // operand C
-      repClusterOuter = repCluster[0];
-      repClusterInner = repCluster[1];
+      repClusterOuter = repCluster[rank - 2];
+      repClusterInner = repCluster[rank - 1];
       break;
     default:
       assert(false && "invalid operand type in lowering");
@@ -289,23 +299,26 @@ class DotOpDPASConversionHelper {

     size_t totalElems = elems.size();
     size_t numElemsPerOperand =
-        totalElems / ((outer * inner) * (repClusterOuter * repClusterInner));
+        totalElems /
+        ((batch * outer * inner) * (repClusterOuter * repClusterInner));
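+    // Dividing out every replica count leaves the number of elements held
+    // in one per-instruction operand vector.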
     VectorType dotOpTy = vec_ty(elemTy, numElemsPerOperand);

     int offset = 0;
     ValueTable vals;
-    for (int i = 0; i < outer; ++i) {
-      for (int j = 0; j < inner; ++j) {
-        for (int repOuter = 0; repOuter < repClusterOuter; ++repOuter) {
-          for (int repInner = 0; repInner < repClusterInner; ++repInner) {
-            Value matVal = rewriter.create(loc, dotOpTy);
-            for (int k = 0; k < numElemsPerOperand; ++k) {
-              matVal =
-                  insert_element(dotOpTy, matVal, elems[offset++], i32_val(k));
+    for (unsigned b = 0; b < batch; ++b) {
+      for (int i = 0; i < outer; ++i) {
+        for (int j = 0; j < inner; ++j) {
+          for (int repOuter = 0; repOuter < repClusterOuter; ++repOuter) {
+            for (int repInner = 0; repInner < repClusterInner; ++repInner) {
+              Value matVal = rewriter.create(loc, dotOpTy);
+              for (int k = 0; k < numElemsPerOperand; ++k) {
+                matVal = insert_element(dotOpTy, matVal, elems[offset++],
+                                        i32_val(k));
+              }
+              vals[{b, i * repClusterOuter + repOuter,
+                    j * repClusterInner + repInner}] =
+                  bitcast(matVal, dotOperandType);
             }
-            vals[{i * repClusterOuter + repOuter,
-                  j * repClusterInner + repInner}] =
-                bitcast(matVal, dotOperandType);
           }
         }
       }
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
index 8d7fee8e38..1e252a6e82 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -511,8 +511,11 @@ struct LoadOpConversion
       DotOperandEncodingAttr dotLayout = getDotEncoding(tensorType).value();
       auto dotOrder = dotLayout.getThreadOrder();
-      const bool valueRowMajor = (dotOrder[0] == 1 && dotOrder[1] == 0);
-      assert((valueRowMajor || (dotOrder[0] == 0 && dotOrder[1] == 1)) &&
+      size_t rank = dotOrder.size();
+      const bool valueRowMajor =
+          (dotOrder[rank - 2] == 1 && dotOrder[rank - 1] == 0);
+      assert((valueRowMajor ||
+              (dotOrder[rank - 2] == 0 && dotOrder[rank - 1] == 1)) &&
              "Only row_major or column_major is allowed");
       const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
@@ -584,27 +587,32 @@ struct LoadOpConversion
       auto repCluster = dpasLayout.getRepCluster();
       SmallVector warpShape =
           isOperandA ? dpasLayout.getShapeA() : dpasLayout.getShapeB();
+
+      unsigned dimOuter = opIdx ? rank - 1 : rank - 2;
+      unsigned dimInner = opIdx ? rank - 2 : rank - 1;
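+      // Outer dim: M for operand A, N for operand B; the inner dim is K in
+      // both cases.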
       unsigned outerDimRequiredWarpNum =
-          mlir::ceil(tensorShape[opIdx], warpShape[opIdx]);
+          mlir::ceil(tensorShape[dimOuter], warpShape[dimOuter]);
       unsigned outerDimWarpNum =
-          std::min(warpsPerCTA[opIdx], outerDimRequiredWarpNum);
+          std::min(warpsPerCTA[dimOuter], outerDimRequiredWarpNum);
       Value outerDimWarpId =
-          urem(multiDimWarpId[opIdx], i32_val(outerDimWarpNum));
+          urem(multiDimWarpId[dimOuter], i32_val(outerDimWarpNum));

       auto [base, baseWidth, baseHeight, rowStride, colStride, offsetBaseX,
             offsetBaseY] =
           getValuesFromBlockPointerStruct(adaptor.getPtr(), rewriter);

-      unsigned tileWidth = elemsPerDPASInst[dotOrder[0]];
-      unsigned tileHeight = elemsPerDPASInst[dotOrder[1]];
+      unsigned tileWidth = elemsPerDPASInst[dotOrder[rank - 2]];
+      unsigned tileHeight = elemsPerDPASInst[dotOrder[rank - 1]];
       unsigned vBlocks = 1;
       unsigned numOperandsOuterDimPerLoad = 1;
       unsigned numOperandsInnerDimPerLoad = 1;

       unsigned numOperandsPer2DLoadM, numOperandsPer2DloadN;
       if (!isTransposeRequired) {
-        numOperandsPer2DLoadM = isOperandA ? repCluster[opIdx] : numReps[!opIdx];
-        numOperandsPer2DloadN = isOperandA ? numReps[!opIdx] : repCluster[opIdx];
+        numOperandsPer2DLoadM =
+            isOperandA ? repCluster[dimOuter] : numReps[opIdx ? 1 : 2];
+        numOperandsPer2DloadN =
+            isOperandA ? numReps[opIdx ? 1 : 2] : repCluster[dimOuter];
       } else {
         if (isOperandA)
           return failure();
@@ -620,7 +628,7 @@ struct LoadOpConversion
         // Note: the tileHeight and numOperandsPer2DLoadM are the column size
        // now.
         numOperandsPer2DLoadM =
-            (threadsPerWarp <= tileHeight) ? repCluster[1] : 1;
+            (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
         // The transpose 2d load only support 1 operand per inst on column.
         // (vBlocks = 1)
         numOperandsPer2DloadN = 1;
@@ -647,7 +655,7 @@ struct LoadOpConversion
       std::swap(numOperandsOuterDimPerLoad, numOperandsInnerDimPerLoad);

       unsigned numLoadPerOutRepCluster =
-          mlir::ceil(repCluster[opIdx], numOperandsOuterDimPerLoad);
+          mlir::ceil(repCluster[dimOuter], numOperandsOuterDimPerLoad);

       unsigned numValuesPerLoad = packedElemsPerLanePerDPASInst *
                                   numOperandsOuterDimPerLoad *
@@ -656,13 +664,14 @@ struct LoadOpConversion
           LLVM::getFixedVectorType(loadResultElemType, numValuesPerLoad);

       // The stride for the replicates.
-      unsigned repOuterStride = warpShape[opIdx] * outerDimWarpNum;
-      unsigned repStride = elemsPerDPASInst[opIdx] * numOperandsOuterDimPerLoad;
-      unsigned warpOuterStride = warpShape[opIdx];
-      unsigned repKStride = elemsPerDPASInst[opIdx == 0 ? 1 : 0];
+      unsigned repOuterStride = warpShape[dimOuter] * outerDimWarpNum;
+      unsigned repStride =
+          elemsPerDPASInst[dimOuter] * numOperandsOuterDimPerLoad;
+      unsigned warpOuterStride = warpShape[dimOuter];
+      unsigned repKStride = elemsPerDPASInst[dimInner];

-      unsigned numRepOuter = numReps[opIdx];
-      unsigned numRepInner = numReps[!opIdx];
+      unsigned numRepOuter = numReps[opIdx ? 2 : 1];
+      unsigned numRepInner = numReps[opIdx ? 1 : 2];
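+      // numReps[0] (the batch replica count) is not consumed by the 2D
+      // block load; only the outer and K components are used here.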

       Value pitch;
       if (memoryRowMajor) {
@@ -1010,6 +1019,7 @@ struct StoreOpConversion
     unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
     Value elemSizeInBytes = i32_val(elemSizeInBits / 8);
     const ArrayRef tensorShape = tensorType.getShape();
+    size_t rank = tensorShape.size();
     unsigned numElems = getTotalElemsPerThread(tensorType);
     SmallVector elemsPerInstr = dpasLayout.getDPASInstShapeC();
     const SmallVector warpsPerCTA = dpasLayout.getWarpsPerCTA();
@@ -1047,21 +1057,23 @@ struct StoreOpConversion
     // A warp stride for the replicates.
     SmallVector repClusterShape = dpasLayout.getShapeC();
     unsigned outerDimWarpNum = std::min(
-        warpsPerCTA[0],
-        mlir::ceil(tensorShape[0], repClusterShape[0]));
+        warpsPerCTA[rank - 2],
+        mlir::ceil(tensorShape[rank - 2], repClusterShape[rank - 2]));
     unsigned innerDimWarpNum = std::min(
-        warpsPerCTA[1],
-        mlir::ceil(tensorShape[1], repClusterShape[1]));
-    Value outerDimWarpId = urem(multiDimWarpId[0], i32_val(outerDimWarpNum));
-    Value innerDimWarpId = urem(multiDimWarpId[1], i32_val(innerDimWarpNum));
-    int64_t numRepOuter = numReps[0];
-    int64_t numRepInner = numReps[1];
+        warpsPerCTA[rank - 1],
+        mlir::ceil(tensorShape[rank - 1], repClusterShape[rank - 1]));
+    Value outerDimWarpId =
+        urem(multiDimWarpId[rank - 2], i32_val(outerDimWarpNum));
+    Value innerDimWarpId =
+        urem(multiDimWarpId[rank - 1], i32_val(innerDimWarpNum));
+    int64_t numRepOuter = numReps[1];
+    int64_t numRepInner = numReps[2];
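+    // For operand C numReps is {batch, M, N}; only the M/N components are
+    // consumed by the block store.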

     std::array replicaStride = {
-        outerDimWarpNum * repClusterShape[0],
-        innerDimWarpNum * repClusterShape[1]};
-    std::array warpStride = {repClusterShape[0],
-                             repClusterShape[1]};
+        outerDimWarpNum * repClusterShape[rank - 2],
+        innerDimWarpNum * repClusterShape[rank - 1]};
+    std::array warpStride = {repClusterShape[rank - 2],
+                             repClusterShape[rank - 1]};

     Value dimWarpId0 = mul(outerDimWarpId, i32_val(warpStride[0]));
     Value dimWarpId1 = mul(innerDimWarpId, i32_val(warpStride[1]));
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
index 85e5f72b0f..2160b8f17d 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
@@ -163,8 +163,10 @@ emitOffsetForDpasLayoutPerCTA(const DpasEncodingAttr &dpasLayout,
   SmallVector instShapeC = dpasLayout.getDPASInstShapeC();
   SmallVector sizePerThreads = getSizePerThread(dpasLayout);
   ArrayRef repCluster = dpasLayout.getRepCluster();
-  SmallVector sizePerDPASInst = {sizePerThreads[0] / repCluster[0],
-                                 sizePerThreads[1] / repCluster[1]};
+  size_t rank = repCluster.size();
+  SmallVector sizePerDPASInst = {
+      sizePerThreads[rank - 2] / repCluster[rank - 2],
+      sizePerThreads[rank - 1] / repCluster[rank - 1]};

   unsigned rowsPerElem = dpasLayout.getSubGroupSize() / instShapeC[1];
   unsigned colsPerElem = 1;

@@ -175,15 +177,21 @@ emitOffsetForDpasLayoutPerCTA(const DpasEncodingAttr &dpasLayout,
   for (unsigned elemId = 0; elemId < elemNumberPerRep; ++elemId) {
     // Follows the C++ order for the dpas layout.
     SmallVector repOffset = {
-        (repId / repCluster[1]) * instShapeC[0],
-        (repId % repCluster[1]) * instShapeC[1]};
+        (repId / repCluster[rank - 1]) * instShapeC[0],
+        (repId % repCluster[rank - 1]) * instShapeC[1]};

     SmallVector elemOffset = {
         (elemId / sizePerDPASInst[1]) * rowsPerElem,
         (elemId % sizePerDPASInst[1]) * colsPerElem};

-    offsets.push_back({repOffset[0] + elemOffset[0] + ctaOffsetX,
-                       repOffset[1] + elemOffset[1] + ctaOffsetY});
+    if (rank == 3)
+      offsets.push_back({0, repOffset[0] + elemOffset[0] + ctaOffsetX,
+                         repOffset[1] + elemOffset[1] + ctaOffsetY});
+    else {
+      assert((rank == 2) && "unexpected rank number for Dpas layout");
+      offsets.push_back({repOffset[0] + elemOffset[0] + ctaOffsetX,
+                         repOffset[1] + elemOffset[1] + ctaOffsetY});
+    }
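+    // The leading 0 is the batch component: the per-CTA offsets are the
+    // same for every batch slice.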
   }
 }
@@ -216,6 +224,7 @@ emitOffsetForDotOpLayout(const DotOperandEncodingAttr &dotLayout,

   unsigned executionSize = dpasLayout.getExecutionSize();
   unsigned opsPerChannel = dpasLayout.getOpsPerChannel();
+  unsigned rank = shape.size();
   unsigned numRowsPerPackedValue = 0u, numColsPerPackedValue = 0u;
   unsigned numColsPerLaneForPackedValue = 0u, numOpsPerPackedValue = 0u;
   switch (opIdx) {
@@ -225,7 +234,7 @@ emitOffsetForDotOpLayout(const DotOperandEncodingAttr &dotLayout,
     SmallVector shapeA = dpasLayout.getShapeA();
     // Unlike the operand B, to pack the value to i16 for scalar bit width <=16.
     numOpsPerPackedValue = opsPerChannel == 4 ? 2 : 1;
-    unsigned packedColNum = shapeA[1] / numOpsPerPackedValue;
+    unsigned packedColNum = shapeA[rank - 1] / numOpsPerPackedValue;
     // Each value name represent multiple rows if warpSize > packedColNum
     numRowsPerPackedValue = mlir::ceil(warpSize, packedColNum);
     numColsPerPackedValue = std::min(warpSize, packedColNum);
@@ -245,13 +254,13 @@ emitOffsetForDotOpLayout(const DotOperandEncodingAttr &dotLayout,
          "numElemPerInstPerRowPerThread should not be zero");

   SmallVector shapePerCTATile = getShapePerCTATile(dotLayout);
-  int64_t numRepOuter = numReps[opIdx];
-  int64_t numRepK = numReps[(opIdx == 0) ? 1 : 0];
+  int64_t numRepOuter = numReps[opIdx ? 2 : 1];
+  int64_t numRepK = numReps[opIdx ? 1 : 2];

   ArrayRef repCluster = dpasLayout.getRepCluster();
-  unsigned repClusterSize = repCluster[opIdx];
+  unsigned repClusterSize = repCluster[opIdx ? rank - 1 : rank - 2];

-  for (unsigned dimOuter = 0; dimOuter < numRepOuter; ++dimOuter)
+  for (unsigned repOuter = 0; repOuter < numRepOuter; ++repOuter)
     for (unsigned k = 0; k < numRepK; ++k)
       for (unsigned rep = 0; rep < repClusterSize; ++rep) {
         for (unsigned elemId = 0; elemId < numElemPerInstPerThread; ++elemId) {
@@ -261,9 +270,9 @@ emitOffsetForDotOpLayout(const DotOperandEncodingAttr &dotLayout,
               (opIdx == 0) ? elemId % numOpsPerPackedValue : 0;
           unsigned packedElemId = elemId / numOpsPerPackedValue;
           unsigned repRowIndex =
-              shapePerCTATile[0] * (opIdx == 0 ? dimOuter : k);
+              shapePerCTATile[rank - 2] * (opIdx == 0 ? repOuter : k);
           unsigned repColIndex =
-              shapePerCTATile[1] * (opIdx == 0 ? k : dimOuter);
+              shapePerCTATile[rank - 1] * (opIdx == 0 ? k : repOuter);
           unsigned repClusterRowIndex = opIdx == 0 ? rep * instShape[0] : 0;
           unsigned repClusterColIndex = opIdx == 0 ? 0 : rep * instShape[1];
           unsigned packedElemRowIndex =
@@ -272,10 +281,19 @@ emitOffsetForDotOpLayout(const DotOperandEncodingAttr &dotLayout,
           unsigned packedElemColIndex =
               (packedElemId % numColsPerLaneForPackedValue) *
               numColsPerPackedValue;
-          offsets.push_back({repRowIndex + repClusterRowIndex +
-                                 packedElemRowIndex + opsRowIndex,
-                             repColIndex + repClusterColIndex +
-                                 packedElemColIndex + opsColIndex});
+          if (rank == 3)
+            offsets.push_back({0,
+                               repRowIndex + repClusterRowIndex +
+                                   packedElemRowIndex + opsRowIndex,
+                               repColIndex + repClusterColIndex +
+                                   packedElemColIndex + opsColIndex});
+          else {
+            assert((rank == 2) && "unexpected rank number for Dot layout");
+            offsets.push_back({repRowIndex + repClusterRowIndex +
+                                   packedElemRowIndex + opsRowIndex,
+                               repColIndex + repClusterColIndex +
+                                   packedElemColIndex + opsColIndex});
+          }
         }
       }

@@ -288,9 +306,10 @@ emitOffsetForDpasLayout(const DpasEncodingAttr &dpasLayout,
   ArrayRef shape = type.getShape();
   SmallVector> offsets;
   SmallVector shapePerCTA = getShapePerCTATile(dpasLayout);
+  size_t rank = shape.size();

-  for (unsigned i = 0; i < shape[0]; i += shapePerCTA[0]) {
-    for (unsigned j = 0; j < shape[1]; j += shapePerCTA[1]) {
+  for (unsigned i = 0; i < shape[rank - 2]; i += shapePerCTA[rank - 2]) {
+    for (unsigned j = 0; j < shape[rank - 1]; j += shapePerCTA[rank - 1]) {
       emitOffsetForDpasLayoutPerCTA(dpasLayout, offsets, i, j);
     }
   }
@@ -329,13 +348,17 @@ emitBaseIndexForDotOpLayout(Location loc, RewriterBase &rewriter,
   SmallVector multiDimWarpId =
       mlir::LLVM::delinearize(rewriter, loc, warpId, warpsPerCTA, order);

+  size_t rank = warpShape.size();
+  assert(rank == shapePerCTA.size() && "Rank mismatch");
   Value warpIndex =
-      (opIdx == 0)
-          ? urem(multiDimWarpId[0],
-                 i32_val(mlir::ceil(shapePerCTA[0], warpShape[0])))
-          : urem(multiDimWarpId[1],
-                 i32_val(mlir::ceil(shapePerCTA[1], warpShape[1])));
-  Value warpOffset = mul(warpIndex, i32_val(warpShape[opIdx]));
+      (opIdx == 0) ? urem(multiDimWarpId[rank - 2],
+                          i32_val(mlir::ceil(shapePerCTA[rank - 2],
+                                             warpShape[rank - 2])))
+                   : urem(multiDimWarpId[rank - 1],
+                          i32_val(mlir::ceil(shapePerCTA[rank - 1],
+                                             warpShape[rank - 1])));
+  Value warpOffset =
+      mul(warpIndex, i32_val(warpShape[opIdx ? rank - 1 : rank - 2]));

   // Compute the 2-dim coordinates of the first element in the warp operated
   // own by this thread.
@@ -351,7 +374,7 @@ emitBaseIndexForDotOpLayout(Location loc, RewriterBase &rewriter,
     // Unlike the operand B, to pack the value to i16 for scalar bit width
     // <=16.
     unsigned packedOpsPerLane = opsPerChannel == 4 ? 2 : 1;
-    unsigned packedColNum = shapeA[1] / packedOpsPerLane;
+    unsigned packedColNum = shapeA[rank - 1] / packedOpsPerLane;
     if (warpSize < packedColNum)
       llvm::report_fatal_error(
           "DpasEncodingAttr sub-group size could not "
@@ -371,12 +394,18 @@ emitBaseIndexForDotOpLayout(Location loc, RewriterBase &rewriter,
     laneRowIndex = mul(laneRowIndex, i32_val(opsPerChannel));
     laneColIndex = urem(laneId, i32_val(executionSize));
   } break;
+  default: {
+    llvm::report_fatal_error("Only support opIdx 1 or 0 for DotOpLayout.");
+  }
   }

-  auto multiDimBase =
-      (opIdx == 0)
-          ? SmallVector{add(laneRowIndex, warpOffset), laneColIndex}
-          : SmallVector{laneRowIndex, add(laneColIndex, warpOffset)};
+  SmallVector multiDimBase(rank);
+  if (rank == 3)
+    multiDimBase[0] = multiDimWarpId[0];
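+  // With rank 3 each warp along dim 0 owns one batch slice, so the batch
+  // base index is just the warp id along that dimension.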
+  multiDimBase[rank - 2] =
+      (opIdx == 0) ? add(laneRowIndex, warpOffset) : laneRowIndex;
+  multiDimBase[rank - 1] =
+      (opIdx == 0) ? laneColIndex : add(laneColIndex, warpOffset);
   return multiDimBase;
 }
@@ -390,6 +419,7 @@ emitBaseIndexForDpasLayout(Location loc, RewriterBase &rewriter,
   Value warpId = udiv(threadId, warpSize);
   Value laneId = urem(threadId, warpSize);

+  size_t rank = type.getShape().size();
   auto warpsPerCTA = dpasLayout.getWarpsPerCTA();
   ArrayRef shape = type.getShape();

@@ -400,19 +430,25 @@ emitBaseIndexForDpasLayout(Location loc, RewriterBase &rewriter,
   // Compute the 2-dim coordinates of the warp containing the tensor element
   // operated on by this thread.
   SmallVector warpShape = dpasLayout.getShapeC();
-  Value rowWarpId = urem(multiDimWarpId[0],
-                         i32_val(mlir::ceil(shape[0], warpShape[0])));
-  Value colWarpId = urem(multiDimWarpId[1],
-                         i32_val(mlir::ceil(shape[1], warpShape[1])));
-  Value rowWarpOffset = mul(rowWarpId, i32_val(warpShape[0]));
-  Value colWarpOffset = mul(colWarpId, i32_val(warpShape[1]));
+  Value rowWarpId =
+      urem(multiDimWarpId[rank - 2],
+           i32_val(mlir::ceil(shape[rank - 2], warpShape[rank - 2])));
+  Value colWarpId =
+      urem(multiDimWarpId[rank - 1],
+           i32_val(mlir::ceil(shape[rank - 1], warpShape[rank - 1])));
+  Value rowWarpOffset = mul(rowWarpId, i32_val(warpShape[rank - 2]));
+  Value colWarpOffset = mul(colWarpId, i32_val(warpShape[rank - 1]));

   // Compute the 2-dim coordinates of the first element in the warp operated
   // on by this thread.
   SmallVector threadsPerWarp = getThreadsPerWarp(dpasLayout);
-  SmallVector multiDimBase = {
-      add(udiv(laneId, i32_val(threadsPerWarp[1])), rowWarpOffset),
-      add(urem(laneId, i32_val(threadsPerWarp[1])), colWarpOffset)};
+  SmallVector multiDimBase(rank);
+  if (rank == 3)
+    multiDimBase[0] = multiDimWarpId[0];
+  multiDimBase[rank - 2] =
+      add(udiv(laneId, i32_val(threadsPerWarp[rank - 1])), rowWarpOffset);
+  multiDimBase[rank - 1] =
+      add(urem(laneId, i32_val(threadsPerWarp[rank - 1])), colWarpOffset);
   return multiDimBase;
 }
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp
index b36add481f..3e636d5bae 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp
@@ -69,29 +69,39 @@ SmallVector getWarpsPerTile(DotOp dotOp,
     if (isa(op) && (op != dotOp))
       return {numWarps, 1};

-  SmallVector ret{1, 1};
-  SmallVector shapePerWarp{dpasCap.repeatCount, dpasCap.executionSize};
+  size_t rank = shape.size();
+  SmallVector ret(rank, 1);
+
+  if (rank == 3) {
+    int batchWarp = numWarps;
+    while (batchWarp > shape[0])
+      batchWarp /= 2;
+    ret[0] = batchWarp;
+    numWarps /= batchWarp;
+  }
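+  // Give the batch dimension as many warps as fit (halving until the count
+  // no longer exceeds shape[0]); the remaining warps tile M and N below.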

   // Try to find a proper tiling shape for the dot operation.
   // It doubles the warp number in col or row in each time based on column to
   // width ratio.
   // By this, we can minimize the duplication of the dot operands A and B.
+  SmallVector shapePerWarp{dpasCap.repeatCount, dpasCap.executionSize};
   uint32_t rowColRatio = ceil(dpasCap.repeatCount, dpasCap.executionSize);
   uint32_t colRowRatio = ceil(dpasCap.executionSize, dpasCap.repeatCount);
+  int rowDim = rank - 2, colDim = rank - 1;
   do {
-    if (ret[0] * ret[1] >= numWarps)
+    if (ret[rowDim] * ret[colDim] >= numWarps)
       break;
-    if (shape[0] / (shapePerWarp[0] * colRowRatio) / ret[0] >=
-        shape[1] / (shapePerWarp[1] * rowColRatio) / ret[1]) {
-      if (ret[0] < shape[0] / shapePerWarp[0])
-        ret[0] *= 2;
+    if (shape[rowDim] / (shapePerWarp[0] * colRowRatio) / ret[rowDim] >=
+        shape[colDim] / (shapePerWarp[1] * rowColRatio) / ret[colDim]) {
+      if (ret[rowDim] < shape[rowDim] / shapePerWarp[0])
+        ret[rowDim] *= 2;
       else
-        ret[1] *= 2;
+        ret[colDim] *= 2;
     } else {
-      ret[1] *= 2;
+      ret[colDim] *= 2;
     }
   } while (true);
@@ -121,6 +131,7 @@ class BlockedToDPAS : public RewritePattern {

     // Create DPAS encoding for the given number of warps
     ArrayRef retShape = oldRetType.getShape();
+    size_t rank = retShape.size();
     ModuleOp mod = funcOp->getParentOfType();
     unsigned numWarps = TritonGPUDialect::getNumWarps(mod);

@@ -145,11 +156,12 @@ class BlockedToDPAS : public RewritePattern {
     unsigned opsPerChan = dpasCap.opsChanBitWidths / dpasElemBitWidths;
     SmallVector warpsPerTile =
         getWarpsPerTile(dotOp, dpasCap, retShape, numWarps);
+    SmallVector repCluster(rank, 1);
     unsigned threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod);
     auto dpasEnc = intel::DpasEncodingAttr::get(
         oldRetType.getContext(), dpasCap.repeatCount, dpasCap.systolicDepth,
-        dpasCap.executionSize, opsPerChan, warpsPerTile, {1, 1},
+        dpasCap.executionSize, opsPerChan, warpsPerTile, repCluster,
         threadsPerWarp);

     if (dpasCap.executionSize == 16 /* PVC */) {
@@ -160,7 +172,7 @@ class BlockedToDPAS : public RewritePattern {
       SmallVector repA =
           dpasEnc.getDPASRepetitions(oldAType.getShape(), 0);
       unsigned repClusterDimM =
-          std::min(maxRepClusterM, static_cast(repA[0]));
+          std::min(maxRepClusterM, static_cast(repA[1]));

       unsigned maxRepClusterN =
           PVC_2D_LOAD_MAXIMUM_BYTES_OF_COLS /
@@ -168,12 +180,14 @@ class BlockedToDPAS : public RewritePattern {
       SmallVector repB =
           dpasEnc.getDPASRepetitions(oldBType.getShape(), 1);
       unsigned repClusterDimN =
-          std::min(maxRepClusterN, static_cast(repB[1]));
+          std::min(maxRepClusterN, static_cast(repB[2]));
+      repCluster[rank - 2] = repClusterDimM;
+      repCluster[rank - 1] = repClusterDimN;

       dpasEnc = intel::DpasEncodingAttr::get(
           oldRetType.getContext(), dpasCap.repeatCount, dpasCap.systolicDepth,
-          dpasCap.executionSize, opsPerChan, warpsPerTile,
-          {repClusterDimM, repClusterDimN}, threadsPerWarp);
+          dpasCap.executionSize, opsPerChan, warpsPerTile, repCluster,
+          threadsPerWarp);
     }

     RankedTensorType newRetType =
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/ReduceDataDuplication.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/ReduceDataDuplication.cpp
index 98fae12c81..54e82ac9c2 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/ReduceDataDuplication.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/ReduceDataDuplication.cpp
@@ -42,6 +42,8 @@ class TritonIntelGPUReduceDataDuplicationPass
         auto srcType = cast(cvtOp.getSrc().getType());
         auto dstType = cast(cvtOp.getType());
         auto srcEncoding = srcType.getEncoding();
+        auto srcOrder = triton::gpu::getOrder(srcEncoding);
+        auto rank = srcOrder.size();
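+        // srcOrder/rank are hoisted up here so that the WarpsPerCTA checks
+        // below work for both 2D and batched 3D layouts.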
         if (isa(srcEncoding))
           return;
         auto dstDotOp =
@@ -52,14 +54,14 @@ class TritonIntelGPUReduceDataDuplicationPass
                 dyn_cast(srcEncoding)) {
           if (srcMmaEncoding.getVersionMajor() != 2 ||
-              (srcMmaEncoding.getWarpsPerCTA()[1] == 1 &&
+              (srcMmaEncoding.getWarpsPerCTA()[rank - 1] == 1 &&
                dstDotOp.getParent() == srcMmaEncoding))
             return;
         }
         if (auto srcMfmaEncoding =
                 dyn_cast(srcEncoding)) {
-          if (srcMfmaEncoding.getWarpsPerCTA()[1] == 1 &&
+          if (srcMfmaEncoding.getWarpsPerCTA()[rank - 1] == 1 &&
               srcMfmaEncoding.getIsTransposed() &&
               dstDotOp.getParent() == srcMfmaEncoding)
             return;
@@ -69,17 +71,15 @@ class TritonIntelGPUReduceDataDuplicationPass
           unsigned opIdx = dstDotOp.getOpIdx();
           if ((opIdx == 0 /* Operand A */ &&
                dstDotOp.getParent() == srcDpasEncoding &&
-               srcDpasEncoding.getWarpsPerCTA()[1] ==
+               srcDpasEncoding.getWarpsPerCTA()[rank - 1] ==
                    1 /* No parallel on N dim */) ||
               (opIdx == 1 /* Operand B */ &&
                dstDotOp.getParent() == srcDpasEncoding &&
-               srcDpasEncoding.getWarpsPerCTA()[0] ==
+               srcDpasEncoding.getWarpsPerCTA()[rank - 2] ==
                    1 /* No parallel on M dim */))
             /* The destination dot layout has no duplication. */
             return;
         }
-        auto srcOrder = triton::gpu::getOrder(srcEncoding);
-        auto rank = srcOrder.size();
         SmallVector sharedOrder;
         if (rank == 3) {
           // add all elements except the element that is zero