Skip to content

Commit 27f8636

Browse files
authored
Merge branch 'main' into use-llvm-func-attrs
2 parents f800e9f + 7551a90 commit 27f8636

File tree

10 files changed

+160
-130
lines changed

10 files changed

+160
-130
lines changed

python/setup.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,13 @@ def find_visual_studio(version_ranges):
119119
for version_range in version_ranges:
120120
command = [
121121
str(vswhere), "-version", version_range, "-requires", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64",
122-
"-property", "installationPath", "-prerelease"
122+
"-products", "*", "-property", "installationPath", "-prerelease"
123123
]
124124

125125
try:
126126
output = subprocess.check_output(command, text=True).strip()
127127
if output:
128-
return output
128+
return output.split("\n")[0]
129129
except subprocess.CalledProcessError:
130130
continue
131131

@@ -146,6 +146,13 @@ def set_env_vars(vs_path, arch="x64"):
146146
os.environ[var] = value
147147

148148

149+
def initialize_visual_studio_env(version_ranges, arch="x64"):
150+
vs_path = find_visual_studio(version_ranges)
151+
if not vs_path:
152+
raise EnvironmentError("Visual Studio not found in specified version ranges.")
153+
set_env_vars(vs_path, arch)
154+
155+
149156
# Taken from https://github.com/pytorch/pytorch/blob/master/tools/setup_helpers/env.py
150157
def check_env_flag(name: str, default: str = "") -> bool:
151158
return os.getenv(name, default).upper() in ["ON", "1", "YES", "TRUE", "Y"]
@@ -447,10 +454,7 @@ def build_extension(self, ext):
447454
lit_dir = shutil.which('lit')
448455
ninja_dir = shutil.which('ninja')
449456
if platform.system() == "Windows":
450-
vs_path = find_visual_studio(["[17.0,18.0)", "[16.0,17.0)"])
451-
env = set_env_vars(vs_path)
452-
if not vs_path:
453-
raise EnvironmentError("Visual Studio 2019 or 2022 not found.")
457+
initialize_visual_studio_env(["[17.0,18.0)", "[16.0,17.0)"])
454458
# lit is used by the test suite
455459
thirdparty_cmake_args = get_thirdparty_packages([get_llvm_package_info()])
456460
thirdparty_cmake_args += self.get_pybind11_cmake_args()

python/triton/runtime/CLFinder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ def find_visual_studio(version_ranges):
1919
for version_range in version_ranges:
2020
command = [
2121
str(vswhere), "-version", version_range, "-requires", "Microsoft.VisualStudio.Component.VC.Tools.x86.x64",
22-
"-property", "installationPath", "-prerelease"
22+
"-products", "*", "-property", "installationPath", "-prerelease"
2323
]
2424

2525
try:
2626
output = subprocess.check_output(command, text=True).strip()
2727
if output:
28-
return output
28+
return output.split("\n")[0]
2929
except subprocess.CalledProcessError:
3030
continue
3131

@@ -37,7 +37,7 @@ def set_env_vars(vs_path, arch="x64"):
3737
if not vcvarsall_path.exists():
3838
raise FileNotFoundError(f"vcvarsall.bat not found in expected path: {vcvarsall_path}")
3939

40-
command = f'call "{vcvarsall_path}" {arch} && set'
40+
command = ["call", vcvarsall_path, arch, "&&", "set"]
4141
output = subprocess.check_output(command, shell=True, text=True)
4242

4343
for line in output.splitlines():

test/TritonGEN/tritongen-invalid.mlir

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -428,19 +428,3 @@ llvm.func @matrix_2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base_hei
428428
triton_gen.2Dblockprefetch %ptr, %base_width, %base_height, %base_pitch, %x, %y {elem_size_in_bits=32, tile_width=32, tile_height=8, v_blocks=1, cache_control=Default} : (!llvm.ptr, i32, i32, i32, i32, i32)
429429
llvm.return
430430
}
431-
432-
// -----
433-
434-
llvm.func @triton_gen.simdblockread(%ptr: !llvm.ptr<3>) {
435-
// expected-error @+1 {{'triton_gen.simdblockread' op unsupported vector type}}
436-
%ret = triton_gen.simdblockread %ptr : (!llvm.ptr<3>) -> vector<64xi16>
437-
llvm.return
438-
}
439-
440-
// -----
441-
442-
llvm.func @triton_gen.simdblockwrite(%ptr: !llvm.ptr<3>, %val: vector<64xi16>) {
443-
// expected-error @+1 {{'triton_gen.simdblockwrite' op unsupported vector type}}
444-
triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr<3>, vector<64xi16>)
445-
llvm.return
446-
}

test/TritonGEN/tritongen-to-llvm.mlir

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -241,20 +241,29 @@ llvm.func @triton_gen.dpas.bf16_accum(%c: vector<8xbf16>, %a : vector<8xi16>, %b
241241

242242
// CHECK: llvm.func spir_funccc @_Z30intel_sub_group_block_read_us2PU3AS3t(!llvm.ptr<3>) -> vector<2xi16> attributes {memory_effects = #llvm.memory_effects<other = none, argMem = read, inaccessibleMem = none>, no_unwind, will_return}
243243

244-
llvm.func @triton_gen.simdblockread(%ptr: !llvm.ptr<3>) {
245-
// CHECK: llvm.func @triton_gen.simdblockread(%arg0: !llvm.ptr<3>) {
244+
llvm.func @triton_gen.sub_group_block_read(%ptr: !llvm.ptr<3>) {
245+
// CHECK: llvm.func @triton_gen.sub_group_block_read(%arg0: !llvm.ptr<3>) {
246246
// CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_read_us2PU3AS3t(%arg0) {{.*}} : (!llvm.ptr<3>) -> vector<2xi16>
247-
%ret = triton_gen.simdblockread %ptr : (!llvm.ptr<3>) -> vector<2xi16>
247+
%ret = triton_gen.sub_group_block_read %ptr : !llvm.ptr<3> -> vector<2xi16>
248248
llvm.return
249249
}
250250

251251
// -----
252252

253253
// CHECK: llvm.func spir_funccc @_Z31intel_sub_group_block_write_us2PU3AS3tDv2_t(!llvm.ptr<3>, vector<2xi16>) attributes {memory_effects = #llvm.memory_effects<other = none, argMem = readwrite, inaccessibleMem = none>, no_unwind, will_return}
254254

255-
llvm.func @triton_gen.simdblockwrite(%ptr: !llvm.ptr<3>, %val : vector<2xi16>) {
256-
// CHECK: llvm.func @triton_gen.simdblockwrite(%arg0: !llvm.ptr<3>, %arg1: vector<2xi16>) {
255+
llvm.func @triton_gen.sub_group_block_write(%ptr: !llvm.ptr<3>, %val : vector<2xi16>) {
256+
// CHECK: llvm.func @triton_gen.sub_group_block_write(%arg0: !llvm.ptr<3>, %arg1: vector<2xi16>) {
257257
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_us2PU3AS3tDv2_t(%arg0, %arg1) {{.*}} : (!llvm.ptr<3>, vector<2xi16>) -> ()
258-
triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr<3>, vector<2xi16>)
258+
triton_gen.sub_group_block_write %ptr, %val : !llvm.ptr<3>, vector<2xi16>
259+
llvm.return
260+
}
261+
262+
// -----
263+
264+
llvm.func @triton_gen.sub_group_block_write(%ptr: !llvm.ptr<1>, %val : i32) {
265+
// CHECK: llvm.func @triton_gen.sub_group_block_write(%arg0: !llvm.ptr<1>, %arg1: i32) {
266+
// CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_uiPU3AS1jj(%arg0, %arg1) {{.*}} : (!llvm.ptr<1>, i32) -> ()
267+
triton_gen.sub_group_block_write %ptr, %val : !llvm.ptr<1>, i32
259268
llvm.return
260269
}

test/TritonGEN/tritongen.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -125,16 +125,16 @@ llvm.func @triton_gen.2Dblockprefetch(%ptr : !llvm.ptr, %base_width : i32, %base
125125
llvm.return
126126
}
127127

128-
llvm.func @triton_gen.simdblockread(%ptr : !llvm.ptr) {
129-
// CHECK: llvm.func @triton_gen.simdblockread(%arg0: !llvm.ptr) {
130-
// CHECK-NEXT: triton_gen.simdblockread %arg0 : (!llvm.ptr) -> vector<2xi16>
131-
triton_gen.simdblockread %ptr : (!llvm.ptr) -> vector<2xi16>
128+
llvm.func @triton_gen.sub_group_block_read(%ptr : !llvm.ptr<1>) {
129+
// CHECK: llvm.func @triton_gen.sub_group_block_read(%arg0: !llvm.ptr<1>) {
130+
// CHECK-NEXT: triton_gen.sub_group_block_read %arg0 : !llvm.ptr<1> -> vector<2xi16>
131+
triton_gen.sub_group_block_read %ptr : !llvm.ptr<1> -> vector<2xi16>
132132
llvm.return
133133
}
134134

135-
llvm.func @triton_gen.simdblockwrite(%ptr : !llvm.ptr, %val : vector<2xi16>) {
136-
// CHECK: llvm.func @triton_gen.simdblockwrite(%arg0: !llvm.ptr, %arg1: vector<2xi16>) {
137-
// CHECK-NEXT: triton_gen.simdblockwrite %arg0, %arg1 : (!llvm.ptr, vector<2xi16>)
138-
triton_gen.simdblockwrite %ptr, %val : (!llvm.ptr, vector<2xi16>)
135+
llvm.func @triton_gen.sub_group_block_write(%ptr : !llvm.ptr<3>, %val : i32) {
136+
// CHECK: llvm.func @triton_gen.sub_group_block_write(%arg0: !llvm.ptr<3>, %arg1: i32) {
137+
// CHECK-NEXT: triton_gen.sub_group_block_write %arg0, %arg1 : !llvm.ptr<3>, i32
138+
triton_gen.sub_group_block_write %ptr, %val : !llvm.ptr<3>, i32
139139
llvm.return
140140
}

third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td

Lines changed: 76 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -314,46 +314,96 @@ def TritonGEN_Matrix2DBlockPrefetchOp : TritonGEN_Op<"2Dblockprefetch">,
314314
let hasVerifier = 1;
315315
}
316316

317-
def TritonGEN_SIMDBlockReadOp: TritonGEN_Op<"simdblockread">,
318-
Results<(outs FixedVectorOf<[AnyTypeOf<[AnyI8, AnyI16, AnyI32, AnyI64]>]>:$res)>,
319-
Arguments<(ins
320-
Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr
321-
)> {
322-
323-
let summary = "simd block read";
317+
def TritonGEN_SubGroupBlockMemoryAccessElementType
318+
: AnyTypeOf<[I8, I16, I32, I64],
319+
"Valid sub-group block memory access element type">;
320+
321+
def TritonGEN_SubGroupBlockMemoryAccessType
322+
: AnyTypeOf<[TritonGEN_SubGroupBlockMemoryAccessElementType,
323+
FixedVectorOfLengthAndType<[2, 4, 8],
324+
[TritonGEN_SubGroupBlockMemoryAccessElementType]>,
325+
// Vectors of length 16 only allowed for i8 for now.
326+
FixedVectorOfLengthAndType<[16], [I8]>],
327+
"Valid sub-group block memory access type">;
328+
329+
def TritonGEN_SubGroupBlockMemoryAccessPointerType
330+
: Type<And<[LLVM_AnyPointer.predicate,
331+
Or<[CPred<"::llvm::cast<::mlir::LLVM::LLVMPointerType>($_self)" #
332+
".getAddressSpace() == " #
333+
"static_cast<unsigned>(kCrossWorkgroup)">,
334+
CPred<"::llvm::cast<::mlir::LLVM::LLVMPointerType>($_self)" #
335+
".getAddressSpace() == " #
336+
"static_cast<unsigned>(kWorkgroup)">]>]>,
337+
"LLVM pointer in local or global OpenCL address space",
338+
"::mlir::LLVM::LLVMPointerType">;
339+
340+
def TritonGEN_SubGroupBlockReadOp: TritonGEN_Op<"sub_group_block_read"> {
341+
let summary = "Sub-group block read.";
324342

325343
let description = [{
326-
The `triton_gen.simdblockread` operation performs simd block read from
327-
a start address without laneId offset. The parameters are:
328-
$ptr - the base address to read data
344+
The `triton_gen.sub_group_block_read` reads a scalar or vector for each
345+
work-item in the sub-group from pointer `ptr` as a block operation.
346+
The data is read strided, so the first value is read from:
347+
```
348+
ptr[sub_group_local_id]
349+
```
350+
and the second one is:
351+
```
352+
ptr[sub_group_local_id + sub_group_size]
353+
```
354+
etc.
355+
356+
`ptr` must be aligned to the size of the element type of `res`.
357+
358+
Example:
359+
```mlir
360+
%0 = triton_gen.sub_group_block_read %ptr : !llvm.ptr<1> -> vector<4xi32>
361+
```
329362
}];
330363

364+
let arguments = (ins
365+
Arg<TritonGEN_SubGroupBlockMemoryAccessPointerType, "", [MemRead]>:$ptr);
366+
367+
let results = (outs TritonGEN_SubGroupBlockMemoryAccessType:$res);
368+
331369
let assemblyFormat = [{
332-
operands ` ` attr-dict `:` functional-type(operands, results)
370+
$ptr attr-dict `:` qualified(type($ptr)) `->` type($res)
333371
}];
334-
335-
let hasVerifier = 1;
336372
}
337373

338-
def TritonGEN_SIMDBlockWriteOp : TritonGEN_Op<"simdblockwrite">,
339-
Arguments<(ins
340-
Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr,
341-
FixedVectorOf<[AnyTypeOf<[AnyI8, AnyI16, AnyI32, AnyI64]>]>:$val
342-
)> {
343-
374+
def TritonGEN_SubGroupBlockWriteOp : TritonGEN_Op<"sub_group_block_write"> {
344375
let summary = "simd block write";
345376

346377
let description = [{
347-
The `triton_gen.simdblockwrite` operation performs simd block write to
348-
a start address without laneId offset. The parameters are:
349-
$ptr - the base address to be written
350-
$val - the value vector to write
378+
The `triton_gen.sub_group_block_write` writes a scalar or vector for each
379+
work-item in the sub-group from pointer `ptr` as a block operation.
380+
The data is read strided, so the first value is written to:
381+
```
382+
ptr[sub_group_local_id]
383+
```
384+
and the second one is:
385+
```
386+
ptr[sub_group_local_id + sub_group_size]
387+
```
388+
etc.
389+
390+
`ptr` must be aligned to the size of the element type of `res`.
391+
392+
Example:
393+
```mlir
394+
%0 = triton_gen.sub_group_block_write %ptr, %val : !llvm.ptr<1>, vector<4xi32>
395+
```
351396
}];
352397

398+
let arguments = (ins
399+
Arg<TritonGEN_SubGroupBlockMemoryAccessPointerType, "", [MemRead]>:$ptr,
400+
TritonGEN_SubGroupBlockMemoryAccessType:$val);
401+
402+
let results = (outs);
403+
353404
let assemblyFormat = [{
354-
operands ` ` attr-dict `:` `(` type(operands) `)`
405+
$ptr `,` $val attr-dict `:` qualified(type($ptr)) `,` type($val)
355406
}];
356-
357-
let hasVerifier = 1;
358407
}
408+
359409
#endif // TRITONGEN_OPS

third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -48,18 +48,6 @@ template <typename Op> static LogicalResult verifyMatrixInput(Op op) {
4848
return success();
4949
}
5050

51-
static LogicalResult verifySIMDBlockTy(Operation *op, VectorType vecTy) {
52-
unsigned numElems = vecTy.getNumElements();
53-
IntegerType elemTy = cast<IntegerType>(vecTy.getElementType());
54-
55-
// FIXME: Allow 16xi16 when SPIRV-LLVM translator supports it.
56-
if (numElems != 1 && numElems != 2 && numElems != 4 && numElems != 8 &&
57-
(elemTy.getWidth() != 8 || numElems != 16))
58-
return op->emitOpError("unsupported vector type");
59-
60-
return success();
61-
}
62-
6351
//===----------------------------------------------------------------------===//
6452
// gen.sub_group_reduce
6553
//===----------------------------------------------------------------------===//
@@ -438,19 +426,3 @@ LogicalResult TritonGEN::Matrix2DBlockPrefetchOp::verify() {
438426

439427
return success();
440428
}
441-
442-
//===----------------------------------------------------------------------===//
443-
// gen.simdblockread
444-
//===----------------------------------------------------------------------===//
445-
446-
LogicalResult TritonGEN::SIMDBlockReadOp::verify() {
447-
return verifySIMDBlockTy(*this, getRes().getType());
448-
}
449-
450-
//===----------------------------------------------------------------------===//
451-
// gen.simdblockwrite
452-
//===----------------------------------------------------------------------===//
453-
454-
LogicalResult TritonGEN::SIMDBlockWriteOp::verify() {
455-
return verifySIMDBlockTy(*this, getVal().getType());
456-
}

0 commit comments

Comments
 (0)