Commit 82fae4e
[ANALYSIS] Don't consider descending sequences as contiguous in AxisInfoAnalysis (#4871)
Contiguity is used to issue wide load operations instead of multiple narrower loads. This always assumes that the address of the first element in a sequence can be used to load the whole sequence. If the sequence is descending, this produces a wrong wide load. This patch fixes that by not preserving the contiguity of the RHS for the SubIOp operation.

- [x] I am not making a trivial change, such as fixing a typo in a comment.
- [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how).
- [ ] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - [x] I have added tests.
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - [ ] This PR does not need a test because `FILL THIS IN`.
- Select one of the following.
  - [x] I have not added any `lit` tests.
  - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)

---------

Signed-off-by: Ilya Enkovich <[email protected]>
1 parent e66f5ab commit 82fae4e
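To illustrate the failure mode described in the commit message (a hand-written sketch in plain Python, not the compiler's actual codegen): with a descending index sequence, the first element sits at the highest address, so a wide load anchored at that address covers different elements than the individual scalar loads would.

    # Illustration only: a descending offset sequence such as the one in the
    # regression test below, in_ptr + (512 - x0).
    offsets = [512 - i for i in range(4)]           # [512, 511, 510, 509] -- descending
    wide_load = [offsets[0] + i for i in range(4)]  # a 4-wide load anchored at the first
                                                    # element would read [512, 513, 514, 515]
    assert wide_load != offsets  # the vectorized load fetches the wrong elements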

File tree

3 files changed: +24 −2 lines changed

- lib/Analysis/AxisInfo.cpp
- python/test/regression/test_functional_regressions.py
- test/Analysis/test-alignment.mlir
lib/Analysis/AxisInfo.cpp

Lines changed: 5 additions & 0 deletions
@@ -278,6 +278,11 @@ class AddSubOpAxisInfoVisitor final : public BinaryOpVisitorImpl<OpTy> {
 private:
   int64_t getContiguity(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs,
                         int dim) override {
+    // Contiguity assumes an increasing sequence. So for SubIOp contiguous
+    // RHS doesn't produce a contiguous result.
+    if (isa<arith::SubIOp>(op))
+      return gcd(lhs.getContiguity(dim), rhs.getConstancy(dim));
+
     return std::max(gcd(lhs.getConstancy(dim), rhs.getContiguity(dim)),
                     gcd(lhs.getContiguity(dim), rhs.getConstancy(dim)));
   }
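As a rough illustration of the rule added in this hunk (a Python model only, not the real AxisInfo implementation), here is how the old and new contiguity computations differ on the operands used in the lit test below: a make_range-style tensor (contiguity 128, constancy 1) and a splat constant (contiguity 1, constancy 128).

    from math import gcd

    # Per-dimension (contiguity, constancy) pairs, modeled after the lit test below.
    range_128 = (128, 1)   # e.g. tt.make_range 0..128: fully contiguous
    splat_1   = (1, 128)   # e.g. arith.constant dense<1>: fully constant

    def sub_contiguity_old(lhs, rhs):
        # Previous rule: symmetric in lhs and rhs.
        return max(gcd(lhs[1], rhs[0]), gcd(lhs[0], rhs[1]))

    def sub_contiguity_new(lhs, rhs):
        # New rule for subi: a contiguous RHS yields a descending result,
        # so only the lhs-contiguous / rhs-constant case contributes.
        return gcd(lhs[0], rhs[1])

    print(sub_contiguity_old(splat_1, range_128))  # 128 -- wrongly treated as contiguous
    print(sub_contiguity_new(splat_1, range_128))  # 1   -- matches the updated CHECK line
    print(sub_contiguity_new(range_128, splat_1))  # 128 -- range minus constant stays contiguous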

python/test/regression/test_functional_regressions.py

Lines changed: 15 additions & 0 deletions
@@ -224,3 +224,18 @@ def grid(META):
                  BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, type=type,  #
                  num_stages=num_stages)
     torch.testing.assert_close(torch_output, triton_output, rtol=1e-2, atol=1e-2)
+
+
+def test_reverse_range(device):
+
+    @triton.jit
+    def kernel(in_ptr, out_ptr):
+        x0 = tl.arange(0, 512)
+        tmp0 = tl.load(in_ptr + (512 - x0))
+        tl.store(out_ptr + x0, tmp0)
+
+    data = torch.randn((516, ), dtype=torch.float32, device=device)
+    res = torch.empty((512, ), dtype=torch.float32, device=device)
+    kernel[(1, )](data, res)
+    ref = torch.flip(data[1:513], [0])
+    assert (res == ref).all()
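For reference, the torch.flip expression above follows from the index arithmetic: the kernel reads indices 512 down to 1, which is exactly data[1:513] reversed. A small standalone check of that equivalence (independent of Triton, assuming torch is available):

    import torch

    # Sanity check of the reference value used in the test: descending indices
    # 512, 511, ..., 1 select the same elements as data[1:513] flipped.
    data = torch.randn((516, ), dtype=torch.float32)
    idx = 512 - torch.arange(512)
    assert torch.equal(data[idx], torch.flip(data[1:513], [0]))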

test/Analysis/test-alignment.mlir

Lines changed: 4 additions & 2 deletions
@@ -97,10 +97,12 @@ tt.func @sub() {
   %1 = arith.constant dense<1> : tensor<128xi32>
   // CHECK-NEXT: contiguity = [128], divisibility = [1], constancy = [1], constant_value = <none>
   %2 = arith.subi %0, %1 : tensor<128xi32>
+  // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
+  %3 = arith.subi %1, %0 : tensor<128xi32>
   // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [128], constant_value = 129
-  %3 = arith.constant dense<129> : tensor<128xi32>
+  %4 = arith.constant dense<129> : tensor<128xi32>
   // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128
-  %4 = arith.subi %3, %1 : tensor<128xi32>
+  %5 = arith.subi %4, %1 : tensor<128xi32>
   tt.return
 }

0 commit comments