Skip to content

Conversation

@lialan
Copy link
Member

@lialan lialan commented Oct 30, 2024

  • Supports vector.load and vector.transfer_read ops.
  • In the case of dynamic indexing, use per-element insertion/extraction to build desired narrow type vectors.
  • Fixed wrong function comment of getCompressedMaskOp.

@llvmbot
Copy link
Member

llvmbot commented Oct 30, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: lialan (lialan)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/114169.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+90-26)
  • (added) mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir (+53)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 1d6f8a991d9b5b..04514725c3aeee 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -149,6 +150,61 @@ static Value insertSubvectorInto(RewriterBase &rewriter, Location loc,
                                                        dest, offsets, strides);
 }
 
+static void dynamicallyExtractElementsToVector(
+    RewriterBase &rewriter, Location loc, TypedValue<VectorType> srcVec,
+    Value destVec, OpFoldResult srcOffsetVar, int64_t loopSize) {
+  /*
+  // Create affine maps for the lower and upper bounds
+  AffineMap lowerBoundMap = AffineMap::getConstantMap(0, rewriter.getContext());
+  AffineMap upperBoundMap =
+      AffineMap::getConstantMap(loopSize, rewriter.getContext());
+
+  auto forLoop = rewriter.create<affine::AffineForOp>(
+      loc, ValueRange{}, lowerBoundMap, ValueRange{}, upperBoundMap, 1,
+      ArrayRef<Value>(destVec));
+
+  OpBuilder builder =
+      OpBuilder::atBlockEnd(forLoop.getBody(), rewriter.getListener());
+
+  auto iv = forLoop.getInductionVar();
+
+  auto loopDestVec = forLoop.getRegionIterArgs()[0];
+  auto extractLoc = builder.create<arith::AddIOp>(
+      loc, rewriter.getIndexType(), srcOffsetVar.dyn_cast<Value>(), iv);
+  auto extractElemOp = builder.create<vector::ExtractElementOp>(
+      loc, elemType, srcVec, extractLoc);
+  auto insertElemOp = builder.create<vector::InsertElementOp>(
+      loc, extractElemOp, loopDestVec, iv);
+  builder.create<affine::AffineYieldOp>(loc,
+                                        ValueRange{insertElemOp->getResult(0)});
+  return forLoop->getResult(0);
+  */
+  for (int i = 0; i < loopSize; ++i) {
+    Value extractLoc;
+    if (i == 0) {
+      extractLoc = srcOffsetVar.dyn_cast<Value>();
+    } else {
+      extractLoc = rewriter.create<arith::AddIOp>(
+          loc, rewriter.getIndexType(), srcOffsetVar.dyn_cast<Value>(),
+          rewriter.create<arith::ConstantIndexOp>(loc, i));
+    }
+    auto extractOp =
+        rewriter.create<vector::ExtractOp>(loc, srcVec, extractLoc);
+    rewriter.create<vector::InsertOp>(loc, extractOp, destVec, i);
+  }
+}
+
+static TypedValue<VectorType>
+emulatedVectorLoad(ConversionPatternRewriter &rewriter, Location loc,
+                   Value base, OpFoldResult linearizedIndices, int64_t numBytes,
+                   int64_t scale, Type oldElememtType, Type newElementType) {
+  auto newLoad = rewriter.create<vector::LoadOp>(
+      loc, VectorType::get(numBytes, newElementType), base,
+      getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+  return rewriter.create<vector::BitCastOp>(
+      loc, VectorType::get(numBytes * scale, oldElememtType), newLoad);
+};
+
 namespace {
 
 //===----------------------------------------------------------------------===//
@@ -380,26 +436,29 @@ struct ConvertVectorLoad final : OpConversionPattern<vector::LoadOp> {
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    if (!foldedIntraVectorOffset) {
-      // unimplemented case for dynamic intra vector offset
-      return failure();
-    }
-
+    // always load enough elements which can cover the original elements
+    auto maxintraVectorOffset =
+        foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
     auto numElements =
-        llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
-    auto newLoad = rewriter.create<vector::LoadOp>(
-        loc, VectorType::get(numElements, newElementType), adaptor.getBase(),
-        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
-
-    Value result = rewriter.create<vector::BitCastOp>(
-        loc, VectorType::get(numElements * scale, oldElementType), newLoad);
+        llvm::divideCeil(maxintraVectorOffset + origElements, scale);
+    Value result =
+        emulatedVectorLoad(rewriter, loc, adaptor.getBase(), linearizedIndices,
+                           numElements, scale, oldElementType, newElementType);
 
-    if (isUnalignedEmulation) {
-      result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
-                                    *foldedIntraVectorOffset, origElements);
+    if (foldedIntraVectorOffset) {
+      if (isUnalignedEmulation) {
+        result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
+                                      *foldedIntraVectorOffset, origElements);
+      }
+      rewriter.replaceOp(op, result);
+    } else {
+      auto resultVector = rewriter.create<arith::ConstantOp>(
+          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
+      dynamicallyExtractElementsToVector(
+          rewriter, loc, dyn_cast<TypedValue<VectorType>>(result), resultVector,
+          linearizedInfo.intraVectorOffset, origElements);
+      rewriter.replaceOp(op, resultVector);
     }
-
-    rewriter.replaceOp(op, result);
     return success();
   }
 };
@@ -604,13 +663,10 @@ struct ConvertVectorTransferRead final
             ? getConstantIntValue(linearizedInfo.intraDataOffset)
             : 0;
 
-    if (!foldedIntraVectorOffset) {
-      // unimplemented case for dynamic inra-vector offset
-      return failure();
-    }
-
+    auto maxIntraVectorOffset =
+        foldedIntraVectorOffset ? *foldedIntraVectorOffset : scale - 1;
     auto numElements =
-        llvm::divideCeil(*foldedIntraVectorOffset + origElements, scale);
+        llvm::divideCeil(maxIntraVectorOffset + origElements, scale);
 
     auto newRead = rewriter.create<vector::TransferReadOp>(
         loc, VectorType::get(numElements, newElementType), adaptor.getSource(),
@@ -621,9 +677,17 @@ struct ConvertVectorTransferRead final
         loc, VectorType::get(numElements * scale, oldElementType), newRead);
 
     Value result = bitCast->getResult(0);
-    if (isUnalignedEmulation) {
-      result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
-                                    *foldedIntraVectorOffset, origElements);
+    if (foldedIntraVectorOffset) {
+      if (isUnalignedEmulation) {
+        result = extractSubvectorFrom(rewriter, loc, op.getType(), result,
+                                      *foldedIntraVectorOffset, origElements);
+      }
+    } else {
+      result = rewriter.create<arith::ConstantOp>(
+          loc, op.getType(), rewriter.getZeroAttr(op.getType()));
+      dynamicallyExtractElementsToVector(rewriter, loc, bitCast, result,
+                                         linearizedInfo.intraVectorOffset,
+                                         origElements);
     }
     rewriter.replaceOp(op, result);
 
diff --git a/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir
new file mode 100644
index 00000000000000..a92e62538c5332
--- /dev/null
+++ b/mlir/test/Dialect/Vector/vector-emulate-narrow-type-unaligned-dynamic.mlir
@@ -0,0 +1,53 @@
+// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
+
+// CHECK: #map = affine_map<()[s0, s1] -> ((s0 * 3 + s1) floordiv 4)>
+// CHECK: #map1 = affine_map<()[s0, s1] -> ((s0 * 3 + s1) mod 4)>
+func.func @vector_load_i2(%arg1: index, %arg2: index) -> vector<3x3xi2> {
+    %0 = memref.alloc() : memref<3x3xi2>
+    %c0 = arith.constant 0 : index
+    %c2 = arith.constant 2 : index
+    %cst = arith.constant dense<0> : vector<3x3xi2>
+    %1 = vector.load %0[%arg1, %arg2] : memref<3x3xi2>, vector<3xi2>
+    %2 = vector.insert %1, %cst [0] : vector<3xi2> into vector<3x3xi2>
+    return %2 : vector<3x3xi2>
+}
+
+// CHECK: func @vector_load_i2
+// CHECK: %[[ALLOC:.+]]= memref.alloc() : memref<3xi8>
+// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%arg0, %arg1]
+// CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%arg0, %arg1]
+// CHECK: %[[EMULATED_LOAD:.+]] = vector.load %alloc[%[[LOADADDR1]]] : memref<3xi8>, vector<2xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EMULATED_LOAD]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[ZERO:.+]] = arith.constant dense<0> : vector<3xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[OFFSET:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[OFFSET]]] : i2 from vector<8xi2>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[OFFSET2:.+]] = arith.addi %1, %c2 : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[OFFSET2]]] : i2 from vector<8xi2>
+
+//-----
+
+func.func @vector_transfer_read_i2(%arg1: index, %arg2: index) -> vector<3xi2> {
+ %0 = memref.alloc() : memref<3x3xi2>
+ %c0i2 = arith.constant 0 : i2
+ %1 = vector.transfer_read %0[%arg1, %arg2], %c0i2 {in_bounds = [true]} : memref<3x3xi2>, vector<3xi2>
+ return %1 : vector<3xi2>
+}
+
+// CHECK: func @vector_transfer_read_i2
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
+// CHECK: %[[C0:.+]] = arith.extui %c0_i2 : i2 to i8
+// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%arg0, %arg1]
+// CHECK: %[[LOADADDR2:.+]] = affine.apply #map1()[%arg0, %arg1]
+// CHECK: %[[READ:.+]] = vector.transfer_read %[[ALLOC]][%[[LOADADDR1]]], %[[C0]] : memref<3xi8>, vector<2xi8>
+// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[READ]] : vector<2xi8> to vector<8xi2>
+// CHECK: %[[CST:.+]] = arith.constant dense<0> : vector<3xi2>
+// CHECK: %[[EXTRACT:.+]] = vector.extract %[[BITCAST]][%[[LOADADDR2]]] : i2 from vector<8xi2>
+// CHECK: %[[C1:.+]] = arith.constant 1 : index
+// CHECK: %[[ADDI:.+]] = arith.addi %[[LOADADDR2]], %[[C1]] : index
+// CHECK: %[[EXTRACT2:.+]] = vector.extract %[[BITCAST]][%[[ADDI]]] : i2 from vector<8xi2>
+// CHECK: %[[C2:.+]] = arith.constant 2 : index
+// CHECK: %[[ADDI2:.+]] = arith.addi %[[LOADADDR2]], %[[C2]] : index
+// CHECK: %[[EXTRACT3:.+]] = vector.extract %[[BITCAST]][%[[ADDI2]]] : i2 from vector<8xi2>

@github-actions
Copy link

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Discourse for more information.

@lialan lialan force-pushed the lialan/dynamic_index branch from cc7b19b to 2effa6a Compare October 30, 2024 13:12
@lialan lialan requested a review from banach-space October 30, 2024 17:32
Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I missing where there is a test for dynamic indices?

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also added more static unaligned loading tests

Please send a separate PR.

@lialan lialan force-pushed the lialan/dynamic_index branch from 328a0df to c47d30f Compare November 1, 2024 03:08
@lialan lialan force-pushed the lialan/dynamic_index branch from c47d30f to b777a60 Compare November 1, 2024 03:51
@lialan lialan requested a review from banach-space November 1, 2024 03:59
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for all the improvements. I am sending a few more suggestion, nothing major. Feel free to ignore my nits.

Comment on lines 131 to 132
/// A wrapper function for emitting `vector.extract_strided_slice`. The vector
/// has to be of 1-D shape.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] Rather than saying that the vector has to be 1-D, why not say "Extracts 1-D subvector from a 1-D vector". This way, the intent becomes clearer and the requirements are implied ;-)

Just so that you don't have to guess what I had in mind:

/// A wrapper function to extract a 1-D subvector from the 1-D source vector.

Also, could you remind my why use vector.extract_strided_slice rather than vector.extract?

Btw, none of this is a blocker for this PR. These are nice-to-have improvements, thanks!

Copy link
Member Author

@lialan lialan Nov 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.

if I understand correctly, vector.extract can extract a single element or the whole innermost dimension, but if we want to operate on a certain part of the inner most dimension then we will have to use vector.extract_strided_slice?

Too bad the docs are not super formal so this is how I read it.

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LG, just some minor comments.

@lialan lialan requested a review from dcaballe November 4, 2024 16:12
@github-actions
Copy link

github-actions bot commented Nov 4, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@lialan lialan requested a review from hanhanW November 4, 2024 20:45
Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks! Just final nit.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

// CHECK-LABEL: func @vector_load_i2_dynamic_indexing(
// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index) -> vector<3xi2>
// CHECK: %[[ALLOC:.+]]= memref.alloc() : memref<3xi8>
// CHECK: %[[LOADADDR1:.+]] = affine.apply #map()[%[[ARG0]], %[[ARG1]]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not use #map directly. Those could be anything, the name is not semantically useful. You need to capture what #map is defined as and then make sure it is used the right way

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The mapping still seems off. It needs to pick the maps for each tests. Unblocking, but please fix before landing.

@hanhanW hanhanW merged commit ce112a7 into llvm:main Nov 5, 2024
8 checks passed
@lialan lialan deleted the lialan/dynamic_index branch November 5, 2024 17:35
qiaojbao pushed a commit to GPUOpen-Drivers/llvm-project that referenced this pull request Dec 4, 2024
…fc4896927

Local branch amd-gfx a12fc48 Merged main:0a68171b3c67503f7143856580f1b22a93ef566e into amd-gfx:cbff18bd3aba
Remote branch main ce112a7 [MLIR] support dynamic indexing in `VectorEmulateNarrowTypes` (llvm#114169)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants