Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/triton/Tools/Sys/GetEnv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
"TRITON_F32_DEFAULT",
"TRITON_PREFER_TMEM_16x256_LAYOUT",
"TRITON_ENABLE_EXPERIMENTAL_CONSAN",
"TRITON_INTEL_2DBLOCK_ASSERT",
"TRITON_INTEL_AGGRESSIVE_DPAS_REUSE",
"TRITON_INTEL_ENABLE_BLOCK_IO_ALL_LAYOUTS",
"TRITON_INTEL_ENABLE_DPAS_FOR_WARP_SIZE_32",
Expand Down
55 changes: 55 additions & 0 deletions python/test/unit/intel/block_load_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import torch
import triton

import ctypes
import sys


def run_load_ir(temp_file, elem_size, *args):
out_type = f"i{int(elem_size) * 4}"
ir = f"""
module attributes {{
ttg.target = "xpu",
"ttg.num-warps" = 32 : i32,
"ttg.num-ctas" = 1 : i32,
"ttg.threads-per-warp" = 16 : i32
}} {{
tt.func @dyn_block(
%iptr : i64, %base_width : i32,
%base_height : i32, %base_pitch : i32,
%x : i32, %y : i32) {{
%p0 = llvm.inttoptr %iptr : i64 to !llvm.ptr

%v = triton_gen.2Dblockload %p0, %base_width, %base_height,
%base_pitch, %x, %y
{{ elem_size_in_bits = {elem_size}, tile_width = 8, tile_height = 8,
v_blocks = 1, transpose = false,
vnni_transform = false, cache_control = Default }}
: (!llvm.ptr, i32, i32, i32, i32, i32)
-> vector<1x{out_type}>

// prevent GluonInline
%v_cast = llvm.bitcast %v : vector<1x{out_type}> to {out_type}
llvm.inline_asm has_side_effects asm_dialect = att
"", "r" %v_cast : ({out_type}) -> ()

tt.return
}}
}}
"""

with open(temp_file, "w", encoding="utf-8") as f:
f.write(ir)

kernel = triton.compile(temp_file)

a = torch.zeros((256, 64), dtype=torch.float32, device="xpu")

addr = ctypes.c_int64(a.data_ptr()).value

kernel[(1, 1, 1)](addr, *map(int, args), 0)


if __name__ == "__main__":
fn = globals()[sys.argv[1]]
fn(*sys.argv[2:])
47 changes: 47 additions & 0 deletions python/test/unit/intel/test_block_load.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import pytest
import torch

import os
import signal
import subprocess
import sys
import pathlib
from functools import partial

Expand Down Expand Up @@ -207,3 +212,45 @@ def triton_mm(X, Y, b=None, transpose_x=False, transpose_y=False):
result_tor = fn_tor()
result_tri = fn_tri()
torch.testing.assert_close(result_tri, result_tor, atol=1e-2, rtol=1e-3)


@pytest.mark.parametrize("elem_size, width, height, pitch, x",
[[8, 16777216, 64, 16777216, 0], # width <= 24 bits
[8, 32, 64, 128, 0], # width >= 64
[8, 66, 64, 128, 0], # width % max(4,elemSize) == 0
[8, 128, 16777216, 128, 0], # height <= 24 bits
[8, 128, 64, 16777216, 0], # pitch <= 24 bits
[8, 128, 64, 32, 0], # pitch >= 64
[8, 128, 64, 70, 0], # pitch % 16 == 0
[8, 128, 64, 120, 0], # pitch >= width
[8, 128, 64, 128, 1], # x*elemSize % 4 == 0 (alignment for 8-bit)
[16, 128, 64, 128, 1], # x*elemSize % 4 == 0 (alignment for 16-bit)
])
@pytest.mark.skipif(not is_xpu(), reason="Block load tests are specific to the XPU backend")
@pytest.mark.xfail(
not (torch.xpu.get_device_capability()['has_subgroup_2d_block_io']
and torch.xpu.get_device_capability()['has_subgroup_matrix_multiply_accumulate']),
reason="Block loads and/or DPAS not supported on this architecture", run=False)
def test_block_load_asserts(elem_size, width, height, pitch, x, monkeypatch, tmp_path: pathlib.Path):
monkeypatch.setenv("TRITON_INTEL_2DBLOCK_ASSERT", "1")

dir_path = os.path.dirname(os.path.realpath(__file__))
helper_path = os.path.join(dir_path, "block_load_helper.py")

temp_file = tmp_path / "test_block_load_asserts.ttgir"

proc = subprocess.run(
[
sys.executable, helper_path, "run_load_ir",
str(temp_file),
str(elem_size),
str(width),
str(height),
str(pitch),
str(x)
],
capture_output=True,
)

rc = proc.returncode
assert rc == -signal.SIGABRT
196 changes: 0 additions & 196 deletions test/Triton/Intel/FuseReshape/fuse-reshape.mlir

This file was deleted.

Loading
Loading