@jpienaar (Member) commented Oct 8, 2025

Done very mechanically.

@llvmbot (Member) commented Oct 8, 2025

@llvm/pr-subscribers-mlir

Author: Jacques Pienaar (jpienaar)

Changes


Patch is 23.67 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/162429.diff

5 Files Affected:

  • (modified) mlir/include/mlir-c/Rewrite.h (+97-1)
  • (modified) mlir/lib/Bindings/Python/Rewrite.cpp (+118)
  • (modified) mlir/lib/CAPI/Transforms/Rewrite.cpp (+141-4)
  • (modified) mlir/test/CAPI/rewrite.c (+47)
  • (added) mlir/test/python/rewrite.py (+107)
diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h
index 77be1f480eacf..20e078a3c1e81 100644
--- a/mlir/include/mlir-c/Rewrite.h
+++ b/mlir/include/mlir-c/Rewrite.h
@@ -36,6 +36,26 @@ extern "C" {
 DEFINE_C_API_STRUCT(MlirRewriterBase, void);
 DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void);
 DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void);
+
+/// Greedy rewrite strictness levels.
+typedef enum {
+  /// No restrictions wrt. which ops are processed.
+  MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP,
+  /// Only pre-existing and newly created ops are processed.
+  MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS,
+  /// Only pre-existing ops are processed.
+  MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS
+} MlirGreedyRewriteStrictness;
+
+/// Greedy simplify region levels.
+typedef enum {
+  /// Disable region control-flow simplification.
+  MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED,
+  /// Run the normal simplification (e.g. dead args elimination).
+  MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL,
+  /// Run extra simplifications (e.g. block merging).
+  MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE
+} MlirGreedySimplifyRegionLevel;
 DEFINE_C_API_STRUCT(MlirRewritePatternSet, void);
 DEFINE_C_API_STRUCT(MlirPatternRewriter, void);
 
@@ -308,7 +328,83 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp(
 
 MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily(
     MlirModule op, MlirFrozenRewritePatternSet patterns,
-    MlirGreedyRewriteDriverConfig);
+    MlirGreedyRewriteDriverConfig config);
+
+//===----------------------------------------------------------------------===//
+/// GreedyRewriteDriverConfig API
+//===----------------------------------------------------------------------===//
+
+/// Creates a greedy rewrite driver configuration with default settings.
+MLIR_CAPI_EXPORTED MlirGreedyRewriteDriverConfig
+mlirGreedyRewriteDriverConfigCreate(void);
+
+/// Destroys a greedy rewrite driver configuration.
+MLIR_CAPI_EXPORTED void
+mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig config);
+
+/// Sets the maximum number of iterations for the greedy rewrite driver.
+/// Use -1 for no limit.
+MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetMaxIterations(
+    MlirGreedyRewriteDriverConfig config, int64_t maxIterations);
+
+/// Sets the maximum number of rewrites within an iteration.
+/// Use -1 for no limit.
+MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetMaxNumRewrites(
+    MlirGreedyRewriteDriverConfig config, int64_t maxNumRewrites);
+
+/// Sets whether to use top-down traversal for the initial population of the
+/// worklist.
+MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(
+    MlirGreedyRewriteDriverConfig config, bool useTopDownTraversal);
+
+/// Enables or disables folding during greedy rewriting.
+MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigEnableFolding(
+    MlirGreedyRewriteDriverConfig config, bool enable);
+
+/// Sets the strictness level for the greedy rewrite driver.
+MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigSetStrictness(
+    MlirGreedyRewriteDriverConfig config,
+    MlirGreedyRewriteStrictness strictness);
+
+/// Sets the region simplification level.
+MLIR_CAPI_EXPORTED void
+mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(
+    MlirGreedyRewriteDriverConfig config,
+    MlirGreedySimplifyRegionLevel level);
+
+/// Enables or disables constant CSE.
+MLIR_CAPI_EXPORTED void mlirGreedyRewriteDriverConfigEnableConstantCSE(
+    MlirGreedyRewriteDriverConfig config, bool enable);
+
+/// Gets the maximum number of iterations for the greedy rewrite driver.
+MLIR_CAPI_EXPORTED int64_t mlirGreedyRewriteDriverConfigGetMaxIterations(
+    MlirGreedyRewriteDriverConfig config);
+
+/// Gets the maximum number of rewrites within an iteration.
+MLIR_CAPI_EXPORTED int64_t mlirGreedyRewriteDriverConfigGetMaxNumRewrites(
+    MlirGreedyRewriteDriverConfig config);
+
+/// Gets whether top-down traversal is used for initial worklist population.
+MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(
+    MlirGreedyRewriteDriverConfig config);
+
+/// Gets whether folding is enabled during greedy rewriting.
+MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigIsFoldingEnabled(
+    MlirGreedyRewriteDriverConfig config);
+
+/// Gets the strictness level for the greedy rewrite driver.
+MLIR_CAPI_EXPORTED MlirGreedyRewriteStrictness
+mlirGreedyRewriteDriverConfigGetStrictness(
+    MlirGreedyRewriteDriverConfig config);
+
+/// Gets the region simplification level.
+MLIR_CAPI_EXPORTED MlirGreedySimplifyRegionLevel
+mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(
+    MlirGreedyRewriteDriverConfig config);
+
+/// Gets whether constant CSE is enabled.
+MLIR_CAPI_EXPORTED bool mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(
+    MlirGreedyRewriteDriverConfig config);
 
 //===----------------------------------------------------------------------===//
 /// PDLPatternModule API
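Neither new enum gives its enumerators explicit values, so C numbers them from 0 in declaration order. That is why the C test in `mlir/test/CAPI/rewrite.c` expects `Strictness: 2` for `MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS` and `RegionSimplificationLevel: 1` for `MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL`. The sketch below copies the two declarations (without their doc comments) instead of including `mlir-c/Rewrite.h`, so it builds without MLIR. It pins the numbering with C11 `static_assert`s. Reordering the enumerators would silently change the integer values that compiled C clients see, so checks like these could be one way to guard against that. They are an illustration, not part of the patch.

```c
#include <assert.h> /* static_assert (C11) */

/* Copies of the two enums declared in mlir-c/Rewrite.h, repeated here so
   this file compiles without the MLIR headers. */
typedef enum {
  MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP,
  MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS,
  MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS
} MlirGreedyRewriteStrictness;

typedef enum {
  MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED,
  MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL,
  MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE
} MlirGreedySimplifyRegionLevel;

/* Enumerators without initializers count up from 0, so these hold at
   compile time; reordering either enum would make them fail. */
static_assert(MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP == 0, "ANY_OP");
static_assert(MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS == 1,
              "EXISTING_AND_NEW_OPS");
static_assert(MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS == 2,
              "EXISTING_OPS");
static_assert(MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED == 0, "DISABLED");
static_assert(MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL == 1, "NORMAL");
static_assert(MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE == 2,
              "AGGRESSIVE");
```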
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp
index 836f44fd7d4be..6908e9423e5a3 100644
--- a/mlir/lib/Bindings/Python/Rewrite.cpp
+++ b/mlir/lib/Bindings/Python/Rewrite.cpp
@@ -139,10 +139,96 @@ class PyFrozenRewritePatternSet {
   MlirFrozenRewritePatternSet set;
 };
 
+/// Owning wrapper around a GreedyRewriteDriverConfig.
+class PyGreedyRewriteDriverConfig {
+public:
+  PyGreedyRewriteDriverConfig()
+      : config(mlirGreedyRewriteDriverConfigCreate()) {}
+  PyGreedyRewriteDriverConfig(PyGreedyRewriteDriverConfig &&other) noexcept
+      : config(other.config) {
+    other.config.ptr = nullptr;
+  }
+  ~PyGreedyRewriteDriverConfig() {
+    if (config.ptr != nullptr)
+      mlirGreedyRewriteDriverConfigDestroy(config);
+  }
+  MlirGreedyRewriteDriverConfig get() { return config; }
+
+  void setMaxIterations(int64_t maxIterations) {
+    mlirGreedyRewriteDriverConfigSetMaxIterations(config, maxIterations);
+  }
+
+  void setMaxNumRewrites(int64_t maxNumRewrites) {
+    mlirGreedyRewriteDriverConfigSetMaxNumRewrites(config, maxNumRewrites);
+  }
+
+  void setUseTopDownTraversal(bool useTopDownTraversal) {
+    mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(config, useTopDownTraversal);
+  }
+
+  void enableFolding(bool enable) {
+    mlirGreedyRewriteDriverConfigEnableFolding(config, enable);
+  }
+
+  void setStrictness(MlirGreedyRewriteStrictness strictness) {
+    mlirGreedyRewriteDriverConfigSetStrictness(config, strictness);
+  }
+
+  void setRegionSimplificationLevel(MlirGreedySimplifyRegionLevel level) {
+    mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(config, level);
+  }
+
+  void enableConstantCSE(bool enable) {
+    mlirGreedyRewriteDriverConfigEnableConstantCSE(config, enable);
+  }
+
+  int64_t getMaxIterations() {
+    return mlirGreedyRewriteDriverConfigGetMaxIterations(config);
+  }
+
+  int64_t getMaxNumRewrites() {
+    return mlirGreedyRewriteDriverConfigGetMaxNumRewrites(config);
+  }
+
+  bool getUseTopDownTraversal() {
+    return mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(config);
+  }
+
+  bool isFoldingEnabled() {
+    return mlirGreedyRewriteDriverConfigIsFoldingEnabled(config);
+  }
+
+  MlirGreedyRewriteStrictness getStrictness() {
+    return mlirGreedyRewriteDriverConfigGetStrictness(config);
+  }
+
+  MlirGreedySimplifyRegionLevel getRegionSimplificationLevel() {
+    return mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(config);
+  }
+
+  bool isConstantCSEEnabled() {
+    return mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(config);
+  }
+
+private:
+  MlirGreedyRewriteDriverConfig config;
+};
+
 } // namespace
 
 /// Create the `mlir.rewrite` here.
 void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
+  // Enum definitions
+  nb::enum_<MlirGreedyRewriteStrictness>(m, "GreedyRewriteStrictness")
+      .value("ANY_OP", MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP)
+      .value("EXISTING_AND_NEW_OPS", MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS)
+      .value("EXISTING_OPS", MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS);
+
+  nb::enum_<MlirGreedySimplifyRegionLevel>(m, "GreedySimplifyRegionLevel")
+      .value("DISABLED", MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED)
+      .value("NORMAL", MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL)
+      .value("AGGRESSIVE", MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE);
+
   nb::class_<MlirPatternRewriter>(m, "PatternRewriter");
   //----------------------------------------------------------------------------
   // Mapping of the PDLResultList and PDLModule
@@ -228,6 +314,38 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
           },
           nb::keep_alive<1, 3>());
 #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
+
+  nb::class_<PyGreedyRewriteDriverConfig>(m, "GreedyRewriteDriverConfig")
+      .def(nb::init<>(), "Create a greedy rewrite driver config with defaults")
+      .def("set_max_iterations", &PyGreedyRewriteDriverConfig::setMaxIterations,
+           "max_iterations"_a, "Set maximum number of iterations")
+      .def("set_max_num_rewrites", &PyGreedyRewriteDriverConfig::setMaxNumRewrites,
+           "max_num_rewrites"_a, "Set maximum number of rewrites per iteration")
+      .def("set_use_top_down_traversal", &PyGreedyRewriteDriverConfig::setUseTopDownTraversal,
+           "use_top_down"_a, "Set whether to use top-down traversal")
+      .def("enable_folding", &PyGreedyRewriteDriverConfig::enableFolding,
+           "enable"_a, "Enable or disable folding")
+      .def("set_strictness", &PyGreedyRewriteDriverConfig::setStrictness,
+           "strictness"_a, "Set rewrite strictness level")
+      .def("set_region_simplification_level", &PyGreedyRewriteDriverConfig::setRegionSimplificationLevel,
+           "level"_a, "Set region simplification level")
+      .def("enable_constant_cse", &PyGreedyRewriteDriverConfig::enableConstantCSE,
+           "enable"_a, "Enable or disable constant CSE")
+      .def("get_max_iterations", &PyGreedyRewriteDriverConfig::getMaxIterations,
+           "Get maximum number of iterations")
+      .def("get_max_num_rewrites", &PyGreedyRewriteDriverConfig::getMaxNumRewrites,
+           "Get maximum number of rewrites per iteration")
+      .def("get_use_top_down_traversal", &PyGreedyRewriteDriverConfig::getUseTopDownTraversal,
+           "Get whether top-down traversal is used")
+      .def("is_folding_enabled", &PyGreedyRewriteDriverConfig::isFoldingEnabled,
+           "Check if folding is enabled")
+      .def("get_strictness", &PyGreedyRewriteDriverConfig::getStrictness,
+           "Get rewrite strictness level")
+      .def("get_region_simplification_level", &PyGreedyRewriteDriverConfig::getRegionSimplificationLevel,
+           "Get region simplification level")
+      .def("is_constant_cse_enabled", &PyGreedyRewriteDriverConfig::isConstantCSEEnabled,
+           "Check if constant CSE is enabled");
+
   nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")
       .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR,
                    &PyFrozenRewritePatternSet::getCapsule)
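`PyGreedyRewriteDriverConfig` is the sole owner of its C handle. Its move constructor nulls the source's pointer, and its destructor skips null handles, so each config is destroyed exactly once. Because the class declares a move constructor, C++ also defines its implicit copy constructor as deleted, which rules out double frees through accidental copies. The same discipline can be written in plain C. The sketch below uses hypothetical `Config` and `ConfigHandle` types and a live-object counter in place of `mlir::GreedyRewriteConfig` and `MlirGreedyRewriteDriverConfig`. None of these names are MLIR API.

```c
#include <assert.h>
#include <stdlib.h>

/* Hypothetical stand-ins: Config plays the role of the C++ object and
   ConfigHandle the role of the opaque C API struct. */
typedef struct {
  int unused;
} Config;

typedef struct {
  void *ptr;
} ConfigHandle;

/* Number of Config objects allocated but not yet freed. */
static int liveConfigs = 0;

static ConfigHandle configCreate(void) {
  ConfigHandle h = {malloc(sizeof(Config))};
  if (h.ptr != NULL)
    ++liveConfigs;
  return h;
}

/* Frees the object if the handle still owns one, then nulls the handle.
   A null handle is a no-op, like the wrapper's destructor. */
static void configDestroy(ConfigHandle *h) {
  assert(h != NULL);
  if (h->ptr == NULL)
    return;
  free(h->ptr);
  --liveConfigs;
  h->ptr = NULL;
}

/* Transfers ownership from *src to the returned handle and nulls *src,
   like the wrapper's move constructor. */
static ConfigHandle configMove(ConfigHandle *src) {
  assert(src != NULL);
  ConfigHandle dst = *src;
  src->ptr = NULL;
  return dst;
}
```

Destroying both the moved-from and the moved-to handle frees the object only once, which is the guarantee the wrapper's null check provides.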
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp
index 8ee6308cadf83..0741308e19077 100644
--- a/mlir/lib/CAPI/Transforms/Rewrite.cpp
+++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp
@@ -289,18 +289,155 @@ void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
   op.ptr = nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+/// GreedyRewriteDriverConfig API
+//===----------------------------------------------------------------------===//
+
+inline mlir::GreedyRewriteConfig *unwrap(MlirGreedyRewriteDriverConfig config) {
+  assert(config.ptr && "unexpected null config");
+  return static_cast<mlir::GreedyRewriteConfig *>(config.ptr);
+}
+
+inline MlirGreedyRewriteDriverConfig wrap(mlir::GreedyRewriteConfig *config) {
+  return {config};
+}
+
+MlirGreedyRewriteDriverConfig mlirGreedyRewriteDriverConfigCreate() {
+  return wrap(new mlir::GreedyRewriteConfig());
+}
+
+void mlirGreedyRewriteDriverConfigDestroy(MlirGreedyRewriteDriverConfig config) {
+  delete unwrap(config);
+}
+
+void mlirGreedyRewriteDriverConfigSetMaxIterations(
+    MlirGreedyRewriteDriverConfig config, int64_t maxIterations) {
+  unwrap(config)->setMaxIterations(maxIterations);
+}
+
+void mlirGreedyRewriteDriverConfigSetMaxNumRewrites(
+    MlirGreedyRewriteDriverConfig config, int64_t maxNumRewrites) {
+  unwrap(config)->setMaxNumRewrites(maxNumRewrites);
+}
+
+void mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(
+    MlirGreedyRewriteDriverConfig config, bool useTopDownTraversal) {
+  unwrap(config)->setUseTopDownTraversal(useTopDownTraversal);
+}
+
+void mlirGreedyRewriteDriverConfigEnableFolding(
+    MlirGreedyRewriteDriverConfig config, bool enable) {
+  unwrap(config)->enableFolding(enable);
+}
+
+void mlirGreedyRewriteDriverConfigSetStrictness(
+    MlirGreedyRewriteDriverConfig config,
+    MlirGreedyRewriteStrictness strictness) {
+  mlir::GreedyRewriteStrictness cppStrictness;
+  switch (strictness) {
+  case MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP:
+    cppStrictness = mlir::GreedyRewriteStrictness::AnyOp;
+    break;
+  case MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS:
+    cppStrictness = mlir::GreedyRewriteStrictness::ExistingAndNewOps;
+    break;
+  case MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS:
+    cppStrictness = mlir::GreedyRewriteStrictness::ExistingOps;
+    break;
+  }
+  unwrap(config)->setStrictness(cppStrictness);
+}
+
+void mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(
+    MlirGreedyRewriteDriverConfig config,
+    MlirGreedySimplifyRegionLevel level) {
+  mlir::GreedySimplifyRegionLevel cppLevel;
+  switch (level) {
+  case MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED:
+    cppLevel = mlir::GreedySimplifyRegionLevel::Disabled;
+    break;
+  case MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL:
+    cppLevel = mlir::GreedySimplifyRegionLevel::Normal;
+    break;
+  case MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE:
+    cppLevel = mlir::GreedySimplifyRegionLevel::Aggressive;
+    break;
+  }
+  unwrap(config)->setRegionSimplificationLevel(cppLevel);
+}
+
+void mlirGreedyRewriteDriverConfigEnableConstantCSE(
+    MlirGreedyRewriteDriverConfig config, bool enable) {
+  unwrap(config)->enableConstantCSE(enable);
+}
+
+int64_t mlirGreedyRewriteDriverConfigGetMaxIterations(
+    MlirGreedyRewriteDriverConfig config) {
+  return unwrap(config)->getMaxIterations();
+}
+
+int64_t mlirGreedyRewriteDriverConfigGetMaxNumRewrites(
+    MlirGreedyRewriteDriverConfig config) {
+  return unwrap(config)->getMaxNumRewrites();
+}
+
+bool mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(
+    MlirGreedyRewriteDriverConfig config) {
+  return unwrap(config)->getUseTopDownTraversal();
+}
+
+bool mlirGreedyRewriteDriverConfigIsFoldingEnabled(
+    MlirGreedyRewriteDriverConfig config) {
+  return unwrap(config)->isFoldingEnabled();
+}
+
+MlirGreedyRewriteStrictness mlirGreedyRewriteDriverConfigGetStrictness(
+    MlirGreedyRewriteDriverConfig config) {
+  mlir::GreedyRewriteStrictness cppStrictness = unwrap(config)->getStrictness();
+  switch (cppStrictness) {
+  case mlir::GreedyRewriteStrictness::AnyOp:
+    return MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP;
+  case mlir::GreedyRewriteStrictness::ExistingAndNewOps:
+    return MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS;
+  case mlir::GreedyRewriteStrictness::ExistingOps:
+    return MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS;
+  }
+}
+
+MlirGreedySimplifyRegionLevel
+mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(
+    MlirGreedyRewriteDriverConfig config) {
+  mlir::GreedySimplifyRegionLevel cppLevel =
+      unwrap(config)->getRegionSimplificationLevel();
+  switch (cppLevel) {
+  case mlir::GreedySimplifyRegionLevel::Disabled:
+    return MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_DISABLED;
+  case mlir::GreedySimplifyRegionLevel::Normal:
+    return MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL;
+  case mlir::GreedySimplifyRegionLevel::Aggressive:
+    return MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_AGGRESSIVE;
+  }
+}
+
+bool mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(
+    MlirGreedyRewriteDriverConfig config) {
+  return unwrap(config)->isConstantCSEEnabled();
+}
+
 MlirLogicalResult
 mlirApplyPatternsAndFoldGreedily(MlirModule op,
                                  MlirFrozenRewritePatternSet patterns,
-                                 MlirGreedyRewriteDriverConfig) {
-  return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
+                                 MlirGreedyRewriteDriverConfig config) {
+  return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns),
+                                          *unwrap(config)));
 }
 
 MlirLogicalResult
 mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op,
                                        MlirFrozenRewritePatternSet patterns,
-                                       MlirGreedyRewriteDriverConfig) {
-  return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
+                                       MlirGreedyRewriteDriverConfig config) {
+  return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns),
+                                          *unwrap(config)));
 }
 
 //===----------------------------------------------------------------------===//
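The setters and getters above translate between the C enums and their C++ counterparts with exhaustive `switch` statements. The two getters have no `return` after the switch, so compilers such as GCC typically warn that control can reach the end of a non-void function (`-Wreturn-type`). LLVM code usually closes such a switch with `llvm_unreachable`. The sketch below shows the same round-trip mapping in plain C and ends each switch with `abort()`, so an out-of-range value fails loudly. `CppStrictness` is a hypothetical stand-in for `mlir::GreedyRewriteStrictness`. It is deliberately given different numeric values to show that the mapping does not rely on the two enums sharing a numbering.

```c
#include <assert.h>
#include <stdlib.h>

/* Copy of the C enum from mlir-c/Rewrite.h. */
typedef enum {
  MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP,
  MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS,
  MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS
} MlirGreedyRewriteStrictness;

/* Hypothetical stand-in for mlir::GreedyRewriteStrictness, numbered
   differently on purpose. */
typedef enum {
  CPP_STRICTNESS_EXISTING_OPS = 10,
  CPP_STRICTNESS_EXISTING_AND_NEW_OPS = 11,
  CPP_STRICTNESS_ANY_OP = 12
} CppStrictness;

static CppStrictness toCpp(MlirGreedyRewriteStrictness s) {
  switch (s) {
  case MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP:
    return CPP_STRICTNESS_ANY_OP;
  case MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS:
    return CPP_STRICTNESS_EXISTING_AND_NEW_OPS;
  case MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS:
    return CPP_STRICTNESS_EXISTING_OPS;
  }
  /* Reached only for an out-of-range value: fail loudly instead of
     falling off the end of a non-void function. */
  abort();
}

static MlirGreedyRewriteStrictness fromCpp(CppStrictness s) {
  switch (s) {
  case CPP_STRICTNESS_ANY_OP:
    return MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP;
  case CPP_STRICTNESS_EXISTING_AND_NEW_OPS:
    return MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_AND_NEW_OPS;
  case CPP_STRICTNESS_EXISTING_OPS:
    return MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS;
  }
  abort();
}

/* Every C value must survive a C -> C++ -> C round trip. */
static void checkRoundTrip(void) {
  for (int i = MLIR_GREEDY_REWRITE_STRICTNESS_ANY_OP;
       i <= MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS; ++i) {
    MlirGreedyRewriteStrictness s = (MlirGreedyRewriteStrictness)i;
    assert(fromCpp(toCpp(s)) == s);
  }
}
```

The `mlirApplyPatternsAndFoldGreedily` and `mlirApplyPatternsAndFoldGreedilyWithOp` functions now dereference the config through `unwrap`, which asserts on a null pointer. C callers must therefore pass a handle obtained from `mlirGreedyRewriteDriverConfigCreate`.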
diff --git a/mlir/test/CAPI/rewrite.c b/mlir/test/CAPI/rewrite.c
index b33d225767046..0745eb496c1d7 100644
--- a/mlir/test/CAPI/rewrite.c
+++ b/mlir/test/CAPI/rewrite.c
@@ -534,6 +534,52 @@ void testReplaceUses(MlirContext ctx) {
   mlirModuleDestroy(module);
 }
 
+void testGreedyRewriteDriverConfig(MlirContext ctx) {
+  // CHECK-LABEL: @testGreedyRewriteDriverConfig
+  fprintf(stderr, "@testGreedyRewriteDriverConfig\n");
+
+  // Test config creation and destruction
+  MlirGreedyRewriteDriverConfig config = mlirGreedyRewriteDriverConfigCreate();
+
+  // Test all configuration setters
+  mlirGreedyRewriteDriverConfigSetMaxIterations(config, 5);
+  mlirGreedyRewriteDriverConfigSetMaxNumRewrites(config, 100);
+  mlirGreedyRewriteDriverConfigSetUseTopDownTraversal(config, true);
+  mlirGreedyRewriteDriverConfigEnableFolding(config, false);
+  mlirGreedyRewriteDriverConfigSetStrictness(
+      config, MLIR_GREEDY_REWRITE_STRICTNESS_EXISTING_OPS);
+  mlirGreedyRewriteDriverConfigSetRegionSimplificationLevel(
+      config, MLIR_GREEDY_SIMPLIFY_REGION_LEVEL_NORMAL);
+  mlirGreedyRewriteDriverConfigEnableConstantCSE(config, false);
+
+  // Test all configuration getters and verify values
+  // CHECK: MaxIterations: 5
+  fprintf(stderr, "MaxIterations: %lld\n",
+          (long long)mlirGreedyRewriteDriverConfigGetMaxIterations(config));
+  // CHECK: MaxNumRewrites: 100
+  fprintf(stderr, "MaxNumRewrites: %lld\n",
+          (long long)mlirGreedyRewriteDriverConfigGetMaxNumRewrites(config));
+  // CHECK: UseTopDownTraversal: 1
+  fprintf(stderr, "UseTopDownTraversal: %d\n",
+          mlirGreedyRewriteDriverConfigGetUseTopDownTraversal(config));
+  // CHECK: FoldingEnabled: 0
+  fprintf(stderr, "FoldingEnabled: %d\n",
+          mlirGreedyRewriteDriverConfigIsFoldingEnabled(config));
+  // CHECK: Strictness: 2
+  fprintf(stderr, "Strictness: %d\n",
+          mlirGreedyRewriteDriverConfigGetStrictness(config));
+  // CHECK: RegionSimplificationLevel: 1
+  fprintf(stderr, "RegionSimplificationLevel: %d\n",
+          mlirGreedyRewriteDriverConfigGetRegionSimplificationLevel(config));
+  // CHECK: ConstantCSEEnabled: 0
+  fprintf(stderr, "ConstantCSEEnabled: %d\n",
+          mlirGreedyRewriteDriverConfigIsConstantCSEEnabled(config));
+
+  // CHECK: Config test completed successfully
+  fprintf(stderr, "Config test completed successfully\n");
+  mlirGreedyRewriteDriverConfigDestroy(config);
+}
+
 int main(void) {
   MlirContext ctx = mlirContextCreate();
   mlirContextSetAllowUnregisteredDialects(ctx, true);
@@ -547,6 +593,7 @@ int main(void) {
   testMove(ctx);
   testOpModification(ctx);
   testReplaceUses(ctx);
+  testGreedyRewriteDriverConfig(ctx);
 
   mlirContextDestroy(ctx);
   return 0;
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
new file mode 100644
index 0000000000000..6f7deadd3cbba
--- /dev/null
+++ b/mlir/test/python/rewrite.py
@@ -0,0 +1,107 @@
+# RUN: %PYTHON %s | FileCheck %s
+
+import gc
+fr...
[truncated]

"level"_a, "Set region simplification level")
.def("enable_constant_cse", &PyGreedyRewriteDriverConfig::enableConstantCSE,
"enable"_a, "Enable or disable constant CSE")
.def("get_max_iterations", &PyGreedyRewriteDriverConfig::getMaxIterations,
Member

Can we turn these into properties? def_prop IIRC

@makslevental (Contributor) commented Oct 8, 2025

def_prop_rw and def_prop_ro

@jpienaar (Member, Author)

Good idea, done.

config.region_simplification_level = GreedySimplifyRegionLevel.AGGRESSIVE
config.enable_constant_cse = True

# Test all getter methods and print results
@jpienaar (Member, Author)

Note: this does feel like a bit of overkill, since it should already be covered by the C-side tests and by nanobind's testing in general. But I argued with myself: it also shows the API and makes it easier to find.

Contributor

More tests is never bad 🙂
