Skip to content

Conversation

@Groverkss
Copy link
Member

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Oct 29, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Kunwar Grover (Groverkss)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/114137.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+26-22)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+42)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 6184225cb6285d..ee8dccf025a0c6 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -17,6 +17,7 @@
 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -40,22 +41,9 @@ using namespace mlir;
-/// Returns the integer value from the first valid input element, assuming Value
-/// inputs are defined by a constant index ops and Attribute inputs are integer
-/// attributes.
-static uint64_t getFirstIntValue(ValueRange values) {
-  return values[0].getDefiningOp<arith::ConstantIndexOp>().value();
-}
-static uint64_t getFirstIntValue(ArrayRef<Attribute> attr) {
-  return cast<IntegerAttr>(attr[0]).getInt();
-}
+/// Returns the first element of `attr`, zero-extended to uint64_t. Assumes
+/// `attr` is a non-empty ArrayAttr of IntegerAttr. (Mixed OpFoldResult
+/// positions are resolved with getConstantIntValue() instead.)
 static uint64_t getFirstIntValue(ArrayAttr attr) {
   return (*attr.getAsValueRange<IntegerAttr>().begin()).getZExtValue();
 }
-static uint64_t getFirstIntValue(ArrayRef<OpFoldResult> foldResults) {
-  auto attr = foldResults[0].dyn_cast<Attribute>();
-  if (attr)
-    return getFirstIntValue(attr);
-
-  return getFirstIntValue(ValueRange{foldResults[0].get<Value>()});
-}
 
 /// Returns the number of bits for the given scalar/vector type.
 static int getNumBits(Type type) {
@@ -157,9 +145,6 @@ struct VectorExtractOpConvert final
   LogicalResult
   matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (extractOp.hasDynamicPosition())
-      return failure();
-
     Type dstType = getTypeConverter()->convertType(extractOp.getType());
     if (!dstType)
       return failure();
@@ -169,9 +154,17 @@ struct VectorExtractOpConvert final
       return success();
     }
 
-    int32_t id = getFirstIntValue(extractOp.getMixedPosition());
-    rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
-        extractOp, adaptor.getVector(), id);
+    // Constant position -> CompositeExtract; otherwise VectorExtractDynamic.
+    std::optional<int64_t> id =
+        getConstantIntValue(extractOp.getMixedPosition()[0]);
+    if (id.has_value())
+      rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+          extractOp, dstType, adaptor.getVector(),
+          rewriter.getI32ArrayAttr(static_cast<int32_t>(*id)));
+    else
+      rewriter.replaceOpWithNewOp<spirv::VectorExtractDynamicOp>(
+          extractOp, dstType, adaptor.getVector(),
+          adaptor.getDynamicPosition()[0]);
     return success();
   }
 };
@@ -249,9 +242,20 @@ struct VectorInsertOpConvert final
       return success();
     }
 
-    int32_t id = getFirstIntValue(insertOp.getMixedPosition());
-    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
-        insertOp, adaptor.getSource(), adaptor.getDest(), id);
+    std::optional<int64_t> id =
+        getConstantIntValue(insertOp.getMixedPosition()[0]);
+
+    // A constant position lowers to spirv.CompositeInsert, a dynamic one to
+    // spirv.VectorInsertDynamic. Both must take the type-converted operands
+    // from `adaptor` (not `insertOp`), including the destination vector.
+    if (id.has_value())
+      rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+          insertOp, adaptor.getSource(), adaptor.getDest(),
+          static_cast<int32_t>(*id));
+    else
+      rewriter.replaceOpWithNewOp<spirv::VectorInsertDynamicOp>(
+          insertOp, adaptor.getDest(), adaptor.getSource(),
+          adaptor.getDynamicPosition()[0]);
     return success();
   }
 };
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 25ec5d0159bd5d..62210108aa73cf 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -186,6 +186,26 @@ func.func @extract_size1_vector(%arg0 : vector<1xf32>) -> f32 {
 
 // -----
 
+// CHECK-LABEL: @extract_dynamic
+//  CHECK-SAME: %[[VEC:.*]]: vector<4xf32>, %[[IDX:.*]]: index
+//       CHECK:   %[[I32:.+]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i32
+//       CHECK:   spirv.VectorExtractDynamic %[[VEC]][%[[I32]]] : vector<4xf32>, i32
+func.func @extract_dynamic(%vec : vector<4xf32>, %pos : index) -> f32 {
+  %r = vector.extract %vec[%pos] : f32 from vector<4xf32>
+  return %r : f32
+}
+
+// CHECK-LABEL: @extract_dynamic_cst
+//  CHECK-SAME: %[[VEC:.*]]: vector<4xf32>
+//       CHECK:   spirv.CompositeExtract %[[VEC]][1 : i32] : vector<4xf32>
+func.func @extract_dynamic_cst(%vec : vector<4xf32>) -> f32 {
+  %c1 = arith.constant 1 : index
+  %r = vector.extract %vec[%c1] : f32 from vector<4xf32>
+  return %r : f32
+}
+
+// -----
+
 // CHECK-LABEL: @insert
 //  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[S:.*]]: f32
 //       CHECK:   spirv.CompositeInsert %[[S]], %[[V]][2 : i32] : f32 into vector<4xf32>
@@ -216,6 +236,28 @@ func.func @insert_size1_vector(%arg0 : vector<1xf32>, %arg1: f32) -> vector<1xf3
 
 // -----
 
+// CHECK-LABEL: @insert_dynamic
+//  CHECK-SAME: %[[SCALAR:.*]]: f32, %[[VEC:.*]]: vector<4xf32>, %[[IDX:.*]]: index
+//       CHECK:   %[[I32:.+]] = builtin.unrealized_conversion_cast %[[IDX]] : index to i32
+//       CHECK:   spirv.VectorInsertDynamic %[[SCALAR]], %[[VEC]][%[[I32]]] : vector<4xf32>, i32
+func.func @insert_dynamic(%scalar: f32, %vec: vector<4xf32>, %pos: index) -> vector<4xf32> {
+  %r = vector.insert %scalar, %vec[%pos] : f32 into vector<4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @insert_dynamic_cst
+//  CHECK-SAME: %[[SCALAR:.*]]: f32, %[[VEC:.*]]: vector<4xf32>
+//       CHECK:   spirv.CompositeInsert %[[SCALAR]], %[[VEC]][2 : i32] : f32 into vector<4xf32>
+func.func @insert_dynamic_cst(%scalar: f32, %vec: vector<4xf32>) -> vector<4xf32> {
+  %c2 = arith.constant 2 : index
+  %r = vector.insert %scalar, %vec[%c2] : f32 into vector<4xf32>
+  return %r : vector<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @extract_element
 //  CHECK-SAME: %[[V:.*]]: vector<4xf32>, %[[ID:.*]]: i32
 //       CHECK:   spirv.VectorExtractDynamic %[[V]][%[[ID]]] : vector<4xf32>, i32

@Groverkss Groverkss requested a review from kuhar November 5, 2024 09:02
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@Groverkss Groverkss force-pushed the vector-insert-extract-dynamic-spirv branch from 9d41a73 to 46795c8 Compare November 5, 2024 23:09
@Groverkss Groverkss merged commit c96a85a into llvm:main Nov 6, 2024
6 of 8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants