|
25 | 25 | #include "mlir/Pass/PassManager.h" |
26 | 26 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
27 | 27 | #include "shardy/dialect/sdy/ir/utils.h" |
| 28 | +#include "src/enzyme_ad/jax/CheckedRewrite.h" |
28 | 29 | #include "src/enzyme_ad/jax/Dialect/Dialect.h" |
29 | 30 | #include "src/enzyme_ad/jax/Dialect/Ops.h" |
30 | 31 | #include "src/enzyme_ad/jax/Passes/EnzymeHLOPatterns.h" |
@@ -306,86 +307,6 @@ class StaticSlice { |
306 | 307 | } |
307 | 308 | }; |
308 | 309 |
|
309 | | -LogicalResult failIfDynamicShape(Operation *op, PatternRewriter &rewriter) { |
310 | | - for (auto type : op->getResultTypes()) { |
311 | | - auto rType = dyn_cast<RankedTensorType>(type); |
312 | | - if (!rType || !rType.hasStaticShape()) |
313 | | - return rewriter.notifyMatchFailure( |
314 | | - op, "unsupported dynamic shape for output."); |
315 | | - } |
316 | | - |
317 | | - for (auto type : op->getOperandTypes()) { |
318 | | - auto rType = dyn_cast<RankedTensorType>(type); |
319 | | - if (!rType || !rType.hasStaticShape()) |
320 | | - return rewriter.notifyMatchFailure( |
321 | | - op, "unsupported dynamic shape for input."); |
322 | | - } |
323 | | - |
324 | | - return success(); |
325 | | -} |
326 | | - |
327 | | -LogicalResult failIfFuncOpInterfaceHasAttr(Operation *op, StringRef attrName, |
328 | | - PatternRewriter &rewriter) { |
329 | | - if (auto func = op->getParentOfType<FunctionOpInterface>()) { |
330 | | - if (func->hasAttrOfType<UnitAttr>(attrName)) |
331 | | - return rewriter.notifyMatchFailure(op, "disabled by attribute."); |
332 | | - } |
333 | | - |
334 | | - return success(); |
335 | | -} |
336 | | - |
337 | | -static constexpr StringRef kDisablePatternAttrName = |
338 | | - "enzymexla.disable_hlo_opts"; |
339 | | - |
340 | | -template <typename OpTy, typename Child> |
341 | | -struct CheckedOpRewritePattern : public OpRewritePattern<OpTy> { |
342 | | - using Base = OpRewritePattern<OpTy>; |
343 | | - using Base::Base; |
344 | | - |
345 | | - LogicalResult |
346 | | - matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override final { |
347 | | - LogicalResult res = |
348 | | - failIfFuncOpInterfaceHasAttr(op, kDisablePatternAttrName, rewriter); |
349 | | - if (res.failed()) |
350 | | - return res; |
351 | | - |
352 | | - if (!((Child *)this)->supportsDynamicShapes()) { |
353 | | - LogicalResult res = failIfDynamicShape(op, rewriter); |
354 | | - if (res.failed()) |
355 | | - return res; |
356 | | - } |
357 | | - |
358 | | - return ((Child *)this)->matchAndRewriteImpl(op, rewriter); |
359 | | - } |
360 | | - |
361 | | - bool supportsDynamicShapes() { return false; } |
362 | | -}; |
363 | | - |
364 | | -template <template <typename> class TraitType, typename Child> |
365 | | -struct CheckedOpTraitRewritePattern : public OpTraitRewritePattern<TraitType> { |
366 | | - using Base = OpTraitRewritePattern<TraitType>; |
367 | | - using Base::Base; |
368 | | - |
369 | | - LogicalResult |
370 | | - matchAndRewrite(Operation *op, |
371 | | - PatternRewriter &rewriter) const override final { |
372 | | - LogicalResult res = |
373 | | - failIfFuncOpInterfaceHasAttr(op, kDisablePatternAttrName, rewriter); |
374 | | - if (res.failed()) |
375 | | - return res; |
376 | | - |
377 | | - if (!((Child *)this)->supportsDynamicShapes()) { |
378 | | - auto res = failIfDynamicShape(op, rewriter); |
379 | | - if (res.failed()) |
380 | | - return res; |
381 | | - } |
382 | | - |
383 | | - return ((Child *)this)->matchAndRewriteImpl(op, rewriter); |
384 | | - } |
385 | | - |
386 | | - bool supportsDynamicShapes() { return false; } |
387 | | -}; |
388 | | - |
389 | 310 | template <typename OpTy, typename Child> |
390 | 311 | struct NoNanCheckedOpRewritePattern |
391 | 312 | : public CheckedOpRewritePattern<OpTy, Child> { |
|
0 commit comments