Skip to content

Commit 67f647a

Browse files
authored
[IR] Add type checks for atomic_cas (#7578)
1 parent 948ba8f commit 67f647a

File tree

2 files changed

+26
-2
lines changed

2 files changed

+26
-2
lines changed

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,8 +383,14 @@ def TT_AtomicRMWOp : TT_Op<"atomic_rmw", [
383383
}];
384384
}
385385

386-
def TT_AtomicCASOp : TT_Op<"atomic_cas", [SameOperandsAndResultShape,
387-
SameOperandsAndResultEncoding]> {
386+
def TT_AtomicCASOp : TT_Op<"atomic_cas", [
387+
SameOperandsAndResultShape,
388+
SameOperandsAndResultEncoding,
389+
TypesMatchWith<"ptr type matches cmp type", "cmp", "ptr",
390+
"getPointerTypeSameShape($_self)">,
391+
TypesMatchWith<"ptr type matches value type", "val", "ptr",
392+
"getPointerTypeSameShape($_self)">
393+
]> {
388394
let summary = "atomic cas";
389395

390396
let description = [{

test/Triton/invalid.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -546,3 +546,21 @@ tt.func @unsplat_invalid(%arg0: tensor<128xf32>) {
546546
%0 = tt.unsplat %arg0 : tensor<128xf32>
547547
tt.return
548548
}
549+
550+
// -----
551+
552+
tt.func @atomic_cas_different_elem_types(%arg0: tensor<128x!tt.ptr<f32>>, %arg1: tensor<128xi32>) {
553+
%cmp = arith.constant dense<0> : tensor<128xi32>
554+
// expected-error @below {{'tt.atomic_cas' op failed to verify that ptr type matches cmp type}}
555+
%0 = tt.atomic_cas relaxed, gpu, %arg0, %cmp, %arg1 : (tensor<128x!tt.ptr<f32>>, tensor<128xi32>, tensor<128xi32>) -> tensor<128xi32>
556+
tt.return
557+
}
558+
559+
// -----
560+
561+
tt.func @atomic_cas_different_elem_types(%arg0: tensor<128x!tt.ptr<f32>>, %arg1: tensor<128xi32>) {
562+
%cmp = arith.constant dense<0.0> : tensor<128xf32>
563+
// expected-error @below {{'tt.atomic_cas' op failed to verify that ptr type matches value type}}
564+
%0 = tt.atomic_cas relaxed, gpu, %arg0, %cmp, %arg1 : (tensor<128x!tt.ptr<f32>>, tensor<128xf32>, tensor<128xi32>) -> tensor<128xi32>
565+
tt.return
566+
}

0 commit comments

Comments
 (0)