Skip to content

Conversation

@matthias-springer
Copy link
Member

No description provided.

@github-actions
Copy link

github-actions bot commented Oct 13, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff de0fd64bedd23660f557833cc0108c3fb2be3918 ff124133d889c014f00a7739c267d56bf9312e18 --extensions cpp,h -- mlir/include/mlir/Conversion/LLVMCommon/Pattern.h mlir/include/mlir/Transforms/DialectConversion.h mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp mlir/lib/Dialect/Func/Transforms/DecomposeCallGraphTypes.cpp mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp mlir/lib/Dialect/SparseTensor/Transforms/Utils/SparseTensorDescriptor.cpp mlir/lib/Transforms/Utils/DialectConversion.cpp mlir/test/lib/Dialect/Func/TestDecomposeCallGraphTypes.cpp mlir/test/lib/Dialect/Test/TestPatterns.cpp
View the diff from clang-format here.
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index 20a2a10ded..e310b926b9 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -153,70 +153,74 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                        type.isVarArg());
   });
 
-/*
-  // Argument materializations convert from the new block argument types
-  // (multiple SSA values that make up a memref descriptor) back to the
-  // original block argument type. The dialect conversion framework will then
-  // insert a target materialization from the original block argument type to
-  // a legal type.
-  addArgumentMaterialization([&](OpBuilder &builder,
-                                 UnrankedMemRefType resultType,
-                                 ValueRange inputs, Location loc) {
-    if (inputs.size() == 1) {
-      // Bare pointers are not supported for unranked memrefs because a
-      // memref descriptor cannot be built just from a bare pointer.
-      return Value();
-    }
-    Value desc =
-        UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
-    // An argument materialization must return a value of type
-    // `resultType`, so insert a cast from the memref descriptor type
-    // (!llvm.struct) to the original memref type.
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
-        .getResult(0);
-  });
-  addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
-                                 ValueRange inputs, Location loc) {
-    Value desc;
-    if (inputs.size() == 1) {
-      // This is a bare pointer. We allow bare pointers only for function entry
-      // blocks.
-      BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
-      if (!barePtr)
-        return Value();
-      Block *block = barePtr.getOwner();
-      if (!block->isEntryBlock() ||
-          !isa<FunctionOpInterface>(block->getParentOp()))
+  /*
+    // Argument materializations convert from the new block argument types
+    // (multiple SSA values that make up a memref descriptor) back to the
+    // original block argument type. The dialect conversion framework will then
+    // insert a target materialization from the original block argument type to
+    // a legal type.
+    addArgumentMaterialization([&](OpBuilder &builder,
+                                   UnrankedMemRefType resultType,
+                                   ValueRange inputs, Location loc) {
+      if (inputs.size() == 1) {
+        // Bare pointers are not supported for unranked memrefs because a
+        // memref descriptor cannot be built just from a bare pointer.
         return Value();
-      desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
-                                               inputs[0]);
-    } else {
-      desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
-    }
-    // An argument materialization must return a value of type `resultType`,
-    // so insert a cast from the memref descriptor type (!llvm.struct) to the
-    // original memref type.
-    return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
-        .getResult(0);
-  });
-
-*/
+      }
+      Value desc =
+          UnrankedMemRefDescriptor::pack(builder, loc, *this, resultType,
+    inputs);
+      // An argument materialization must return a value of type
+      // `resultType`, so insert a cast from the memref descriptor type
+      // (!llvm.struct) to the original memref type.
+      return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
+          .getResult(0);
+    });
+    addArgumentMaterialization([&](OpBuilder &builder, MemRefType resultType,
+                                   ValueRange inputs, Location loc) {
+      Value desc;
+      if (inputs.size() == 1) {
+        // This is a bare pointer. We allow bare pointers only for function
+    entry
+        // blocks.
+        BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
+        if (!barePtr)
+          return Value();
+        Block *block = barePtr.getOwner();
+        if (!block->isEntryBlock() ||
+            !isa<FunctionOpInterface>(block->getParentOp()))
+          return Value();
+        desc = MemRefDescriptor::fromStaticShape(builder, loc, *this,
+    resultType, inputs[0]); } else { desc = MemRefDescriptor::pack(builder, loc,
+    *this, resultType, inputs);
+      }
+      // An argument materialization must return a value of type `resultType`,
+      // so insert a cast from the memref descriptor type (!llvm.struct) to the
+      // original memref type.
+      return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
+          .getResult(0);
+    });
+
+  */
   // Add generic source and target materializations to handle cases where
   // non-LLVM types persist after an LLVM conversion.
   addSourceMaterialization([&](OpBuilder &builder, Type resultType,
                                ValueRange inputs, Location loc) {
-    //if (inputs.size() != 1)
-    //  return Value();
+    // if (inputs.size() != 1)
+    //   return Value();
 
     return builder.create<UnrealizedConversionCastOp>(loc, resultType, inputs)
         .getResult(0);
   });
   addSourceMaterialization([&](OpBuilder &builder, MemRefType resultType,
                                ValueRange inputs, Location loc) {
-    if (inputs.size()== 1 && isa<LLVM::LLVMStructType>(inputs.front().getType())) return Value();
+    if (inputs.size() == 1 &&
+        isa<LLVM::LLVMStructType>(inputs.front().getType()))
+      return Value();
 
     Value desc;
-    if (inputs.size() == 1 && isa<LLVM::LLVMPointerType>(inputs.front().getType())) {
+    if (inputs.size() == 1 &&
+        isa<LLVM::LLVMPointerType>(inputs.front().getType())) {
       // This is a bare pointer. We allow bare pointers only for function entry
       // blocks.
       BlockArgument barePtr = dyn_cast<BlockArgument>(inputs.front());
@@ -229,10 +233,10 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
       desc = MemRefDescriptor::fromStaticShape(builder, loc, *this, resultType,
                                                inputs[0]);
     } else {
-      //llvm::errs() << "pack elems: " << inputs.size() << "\n";
-      //llvm::errs() << inputs[0] << "\n";
+      // llvm::errs() << "pack elems: " << inputs.size() << "\n";
+      // llvm::errs() << inputs[0] << "\n";
       desc = MemRefDescriptor::pack(builder, loc, *this, resultType, inputs);
-      //llvm::errs() << "done packing\n";
+      // llvm::errs() << "done packing\n";
     }
     // An argument materialization must return a value of type `resultType`,
     // so insert a cast from the memref descriptor type (!llvm.struct) to the
@@ -240,8 +244,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
     return builder.create<UnrealizedConversionCastOp>(loc, resultType, desc)
         .getResult(0);
   });
-  addSourceMaterialization([&](OpBuilder &builder, UnrankedMemRefType resultType,
-                               ValueRange inputs, Location loc) {
+  addSourceMaterialization([&](OpBuilder &builder,
+                               UnrankedMemRefType resultType, ValueRange inputs,
+                               Location loc) {
     if (inputs.size() == 1) {
       // Bare pointers are not supported for unranked memrefs because a
       // memref descriptor cannot be built just from a bare pointer.
@@ -264,8 +269,8 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
         .getResult(0);
   });
   addTargetMaterialization([&](OpBuilder &builder, Type resultType,
-                               ValueRange inputs,
-                               Location loc, Type originalType) -> Value {
+                               ValueRange inputs, Location loc,
+                               Type originalType) -> Value {
     llvm::errs() << "TARGET MAT: -> " << resultType << "\n";
     if (!originalType) {
       llvm::errs() << " -- no orig\n";
@@ -275,8 +280,9 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
       assert(isa<LLVM::LLVMStructType>(resultType) && "expected struct type");
       if (inputs.size() == 1) {
         Value input = inputs.front();
-        if (auto castOp =input.getDefiningOp<UnrealizedConversionCastOp>()) {
-          if (castOp.getInputs().size() == 1 && isa<LLVM::LLVMPointerType>(castOp.getInputs()[0].getType())) {
+        if (auto castOp = input.getDefiningOp<UnrealizedConversionCastOp>()) {
+          if (castOp.getInputs().size() == 1 &&
+              isa<LLVM::LLVMPointerType>(castOp.getInputs()[0].getType())) {
             input = castOp.getInputs()[0];
           }
         }
@@ -290,23 +296,23 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
             !isa<FunctionOpInterface>(block->getParentOp()))
           return Value();
         // Bare ptr
-        return MemRefDescriptor::fromStaticShape(builder, loc, *this, memrefType,
-                                                 input);
+        return MemRefDescriptor::fromStaticShape(builder, loc, *this,
+                                                 memrefType, input);
       }
       return MemRefDescriptor::pack(builder, loc, *this, memrefType, inputs);
     }
     if (auto memrefType = dyn_cast<UnrankedMemRefType>(originalType)) {
       assert(isa<LLVM::LLVMStructType>(resultType) && "expected struct type");
       if (inputs.size() == 1) {
-          // Bare pointers are not supported for unranked memrefs because a
-          // memref descriptor cannot be built just from a bare pointer.
-          return Value();
+        // Bare pointers are not supported for unranked memrefs because a
+        // memref descriptor cannot be built just from a bare pointer.
+        return Value();
       }
-      return UnrankedMemRefDescriptor::pack(builder, loc, *this,
-                                                    memrefType, inputs);
+      return UnrankedMemRefDescriptor::pack(builder, loc, *this, memrefType,
+                                            inputs);
     }
 
-      return Value();
+    return Value();
   });
 
   // Integer memory spaces map to themselves.
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 9b9682148b..a40678f912 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -122,15 +122,16 @@ struct ConversionValueMapping {
     if (it == mapping.end())
       return Value();
     const SmallVector<Value, 1> &repl = it->second;
-    if (repl.size() != 1) return Value();
-    return repl.front();
-/*
-    if (!mapping.contains(from)) return Value();
-    auto it = llvm::find(mapping, from);
-    const SmallVector<Value, 1> &repl = it->second;
-    if (repl.size() != 1) return Value();
+    if (repl.size() != 1)
+      return Value();
     return repl.front();
-    */
+    /*
+        if (!mapping.contains(from)) return Value();
+        auto it = llvm::find(mapping, from);
+        const SmallVector<Value, 1> &repl = it->second;
+        if (repl.size() != 1) return Value();
+        return repl.front();
+        */
   }
 
   /// Find the most recently mapped values for the given value. If the value is
@@ -1299,7 +1300,8 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
     SmallVector<Value, 1> vals = mapping.lookupOrDefault(operand);
     ValueRange castValues = buildUnresolvedMaterialization(
         MaterializationKind::Target, computeInsertPoint(vals), operandLoc,
-        /*inputs=*/vals, /*outputTypes=*/legalTypes, /*originalType=*/origType, currentTypeConverter);
+        /*inputs=*/vals, /*outputTypes=*/legalTypes, /*originalType=*/origType,
+        currentTypeConverter);
 
     mapping.mapMaterialization(vals, castValues);
     remapped.push_back(castValues);
@@ -1454,7 +1456,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
 /// of input operands.
 ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
     MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
-    ValueRange inputs, TypeRange outputTypes, Type originalType, const TypeConverter *converter) {
+    ValueRange inputs, TypeRange outputTypes, Type originalType,
+    const TypeConverter *converter) {
   // Avoid materializing an unnecessary cast.
   if (TypeRange(inputs) == outputTypes)
     return inputs;
@@ -1465,7 +1468,8 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
   auto convertOp =
       builder.create<UnrealizedConversionCastOp>(loc, outputTypes, inputs);
-  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind, originalType);
+  appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
+                                                  originalType);
   return convertOp.getResults();
 }
 
@@ -1686,7 +1690,8 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
   });
   impl->appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from);
   SmallVector<Value, 1> mapped = impl->mapping.lookupOrDefault(from);
-  assert(mapped.size() == 1 && "replaceUsesOfBlockArgument is not supported for 1:N replacements");
+  assert(mapped.size() == 1 &&
+         "replaceUsesOfBlockArgument is not supported for 1:N replacements");
   impl->mapping.map(mapped.front(), to);
 }
 
@@ -2599,7 +2604,8 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
           rewrite->getOriginalType());
       break;
     case MaterializationKind::Source:
-      assert(op.getNumResults() == 1 && "*:N source materializations are not supported");
+      assert(op.getNumResults() == 1 &&
+             "*:N source materializations are not supported");
       Value sourceMat = converter->materializeSourceConversion(
           rewriter, op->getLoc(), op.getResultTypes().front(), inputOperands);
       if (sourceMat)
@@ -2617,7 +2623,8 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
   InFlightDiagnostic diag = op->emitError()
                             << "failed to legalize unresolved materialization "
                                "from ("
-                            << inputOperands.getTypes() << ") to (" << op.getResultTypes()
+                            << inputOperands.getTypes() << ") to ("
+                            << op.getResultTypes()
                             << ") that remained live after conversion";
   diag.attachNote(op->getUsers().begin()->getLoc())
       << "see existing live user here: " << *op->getUsers().begin();
@@ -2741,7 +2748,8 @@ void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
     for (Value originalValue : replacedValues) {
       // If this value is directly replaced with a value of the same type,
       // there is nothing to do.
-      Value repl = rewriterImpl.mapping.lookupDirectSingleReplacement(originalValue);
+      Value repl =
+          rewriterImpl.mapping.lookupDirectSingleReplacement(originalValue);
       if (repl && repl.getType() == originalValue.getType())
         continue;
       // If the type of this value changed and the value is still live, we need
@@ -2761,8 +2769,8 @@ void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
       Value castValue = rewriterImpl.buildUnresolvedMaterialization(
           MaterializationKind::Source, computeInsertPoint(newValues),
           originalValue.getLoc(),
-          /*inputs=*/newValues, /*outputTypes=*/originalValue.getType(), /*originalType=*/Type(),
-          converter)[0];
+          /*inputs=*/newValues, /*outputTypes=*/originalValue.getType(),
+          /*originalType=*/Type(), converter)[0];
       rewriterImpl.mapping.mapMaterialization(newValues, {castValue});
       llvm::append_range(inverseMapping[castValue], newValues);
     }
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 9154964465..9fef1e4819 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1076,7 +1076,8 @@ struct TestUpdateConsumerType : public ConversionPattern {
   LogicalResult
   matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const final {
-    llvm::errs() << "TestUpdateConsumerType operand: " << operands.front() << "\n";
+    llvm::errs() << "TestUpdateConsumerType operand: " << operands.front()
+                 << "\n";
     // Verify that the incoming operand has been successfully remapped to F64.
     if (!operands[0].getType().isF64())
       return failure();

@matthias-springer matthias-springer force-pushed the users/matthias-springer/tmp_no_arg_mat branch from 7ec251b to 820769d Compare October 13, 2024 17:33
@matthias-springer matthias-springer force-pushed the users/matthias-springer/tmp_no_arg_mat branch 6 times, most recently from 0c57792 to 7a7d81a Compare October 31, 2024 09:53
@matthias-springer matthias-springer force-pushed the users/matthias-springer/tmp_no_arg_mat branch 2 times, most recently from 52ebfbc to 76ccdee Compare November 5, 2024 04:15
Apply suggestions from code review

Co-authored-by: Markus Böck <[email protected]>

address comments
do not build argument materializations anymore

fix more tests

Fix decompose call graph test
@matthias-springer matthias-springer force-pushed the users/matthias-springer/tmp_no_arg_mat branch from 76ccdee to ff12413 Compare November 13, 2024 06:07
@matthias-springer matthias-springer changed the base branch from main to users/matthias-springer/replace_with_multiple November 13, 2024 06:07
@matthias-springer matthias-springer force-pushed the users/matthias-springer/replace_with_multiple branch 2 times, most recently from 001453c to f5ed959 Compare November 14, 2024 01:27
Base automatically changed from users/matthias-springer/replace_with_multiple to main November 14, 2024 01:28
@matthias-springer
Copy link
Member Author

This PR was split into multiple PRs.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants