Skip to content

Commit 0dce8de

Browse files
[AxisInfo] Fix stride calculation for MulIOp (#4279)
Fixes #4275. Signed-off-by: Whitney Tsang <[email protected]>
1 parent 386840d commit 0dce8de

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

test/Analysis/intel/test-axis-info.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -897,7 +897,7 @@ tt.func public @make_tensor_ptr(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f8E5M2> {tt.
897897
// -----
898898

899899
// CHECK-LABEL: @ptr_offset
900-
tt.func public @ptr_offset(%arg0: i32) {
900+
tt.func public @ptr_offset(%arg0: i32, %arg1: tensor<128x1xi32>) {
901901
// CHECK: stride = [0, 0], contiguity = [1, 1], divisibility = [512, 512], constancy = [128, 1], constant_value = 512
902902
%cst_0 = arith.constant dense<512> : tensor<128x1xi32>
903903
// CHECK: stride = [0], contiguity = [1], divisibility = [512], constancy = [128], constant_value = 512
@@ -920,5 +920,7 @@ tt.func public @ptr_offset(%arg0: i32) {
920920
%6 = arith.muli %5, %cst_0 : tensor<128x1xi32>
921921
// CHECK: stride = [512, 0], contiguity = [1, 1], divisibility = [512, 512], constancy = [1, 64], constant_value = <none>
922922
%7 = tt.broadcast %6 : tensor<128x1xi32> -> tensor<128x64xi32>
923+
// CHECK: stride = [-1, -1], contiguity = [1, 1], divisibility = [512, 512], constancy = [1, 1], constant_value = <none>
924+
%8 = arith.muli %arg1, %cst_0 : tensor<128x1xi32>
923925
tt.return
924926
}

third_party/intel/lib/Analysis/AxisInfo.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,11 @@ class MulIOpAxisInfoVisitor final : public BinaryOpVisitorImpl<arith::MulIOp> {
400400
return lhs.getStride(dim) * rhs.getConstantValue().value();
401401
if (rhs.getStride(dim) > 0 && lhs.getConstantValue().has_value())
402402
return lhs.getConstantValue().value() * rhs.getStride(dim);
403-
if (lhs.getStride(dim) == 0 || rhs.getStride(dim) == 0)
403+
auto strideZero = [&](const AxisInfo axisInfo) {
404+
return axisInfo.getConstantValue().has_value() ||
405+
axisInfo.getStride(dim) == 0 || !isa<TensorType>(op.getType());
406+
};
407+
if (strideZero(lhs) && strideZero(rhs))
404408
return 0;
405409
return -1;
406410
}

0 commit comments

Comments
 (0)