@@ -14,11 +14,67 @@ class RewritePatternSet;
1414
1515namespace xegpu {
1616
17+ // / Options to control the XeGPU unrolling. Its main purpose is to
18+ // / provide a way to customize the native shape of the operation.
19+ struct UnrollOptions {
20+ // / Callback function that indicates whether vector unrolling should be
21+ // / attempted on the operation.
22+ using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
23+ FilterConstraintFnType filterConstraint = nullptr ;
24+ UnrollOptions &setFilterConstraint (FilterConstraintFnType constraint) {
25+ filterConstraint = std::move (constraint);
26+ return *this ;
27+ }
28+
29+ // / Function that computes the target shape for unrolling. It returns an
30+ // / optional vector of integers representing the shape. If it returns
31+ // / `std::nullopt`, unrolling is aborted for the given operation.
32+ using NativeShapeFnType =
33+ std::function<std::optional<SmallVector<int64_t >>(Operation *op)>;
34+ NativeShapeFnType nativeShape = nullptr ;
35+ UnrollOptions &setNativeShapeFn (NativeShapeFnType fn) {
36+ nativeShape = std::move (fn);
37+ return *this ;
38+ }
39+
40+ // / Function that converts a ShapedType (TensorDescType or VectorType)
41+ // / into the unrolled type based on the tileShape. It returns a vector of
42+ // / types representing the unrolled types for simplicity.
43+ using UnrolledTypeFnType = std::function<SmallVector<Type>(
44+ ShapedType type, ArrayRef<int64_t > tileShape)>;
45+ UnrolledTypeFnType getUnrolledTypes = nullptr ;
46+ UnrollOptions &setUnrolledTypesFn (UnrolledTypeFnType fn) {
47+ getUnrolledTypes = std::move (fn);
48+ return *this ;
49+ }
50+ };
51+
1752// / Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
1853void populateXeGPUFoldAliasOpsPatterns (RewritePatternSet &patterns);
54+
1955// / Appends patterns for XeGPU SIMT distribution into `patterns`.
2056void populateXeGPUSubgroupDistributePatterns (RewritePatternSet &patterns);
2157
58+ // / Collect a set of patterns to unroll xegpu operations to a smaller shapes.
59+ // / Users can control whether an operation to be unrolled or not, as well as
60+ // / its target shape via `options` structure. (via setting filterConstraint
61+ // / and nativeShape respectively, both of them are function refs taking `op` as
62+ // / input).
63+ // / An `op` is unrolled to the `targetShape` as follows, for each of its
64+ // / operands:
65+ // / 1. the unrolled type `unrolledType` and number of unrolled instances
66+ // / `numUnrolledInstances` are computed from the `targetShape`.
67+ // / 2. pack each operand. ExtractStridedSlice are created to break-up the
68+ // / vector operands. And BuiltinUnrealizedCastop are created to break-up
69+ // / the TensorDesc operands.
70+ // / 3. the original op is cloned `numUnrolledInstances` times, once for each
71+ // / result.
72+ // / 4. unpack the results. InsertStridedSlice are inserted for VectorType
73+ // / result, and BuiltinUnrealizedCastOp are inserted for TensorDescType result
74+ // / to re-assemble the slices into the original shape.
75+ void populateXeGPUUnrollPatterns (RewritePatternSet &patterns,
76+ const UnrollOptions &options);
77+
2278} // namespace xegpu
2379} // namespace mlir
2480
0 commit comments