Skip to content

Commit 834b8b7

Browse files
[acc] Add acc.specialized_routine attribute (#170766)
Introduce a new attribute `acc.specialized_routine` to mark functions that have been specialized from a host function marked with `acc.routine_info`. The new attribute captures: - A SymbolRefAttr referencing the original `acc.routine` operation - The parallelism level via the new `ParLevel` enum - The original function name (since specialized functions may be renamed) Example - before specialization: ``` acc.routine @routine_gang func(@foo) gang acc.routine @routine_vector func(@foo) vector func.func @foo() attributes { acc.routine_info = #acc.routine_info<[@routine_gang, @routine_vector]> } { ... } ``` After specialization, there are three functions: the original function and two specialized versions (one per parallelism level): ``` acc.routine @routine_gang func(@foo) gang acc.routine @routine_vector func(@foo) vector // Original function (unchanged) func.func @foo() attributes { acc.routine_info = #acc.routine_info<[@routine_gang, @routine_vector]> } { ... } // Specialized for gang parallelism func.func @foo_gang() attributes { acc.specialized_routine = #acc.specialized_routine<@routine_gang, <gang_dim1>, "foo"> } { ... } // Specialized for vector parallelism func.func @foo_vector() attributes { acc.specialized_routine = #acc.specialized_routine<@routine_vector, <vector>, "foo"> } { ... } ```
1 parent 02ca50e commit 834b8b7

File tree

4 files changed

+148
-4
lines changed

4 files changed

+148
-4
lines changed

mlir/include/mlir/Dialect/OpenACC/OpenACC.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,14 +177,25 @@ static constexpr StringLiteral getDeclareActionAttrName() {
177177
}
178178

179179
static constexpr StringLiteral getRoutineInfoAttrName() {
180-
return StringLiteral("acc.routine_info");
180+
return RoutineInfoAttr::name;
181181
}
182182

183-
/// Used to check whether the current operation is an `acc routine`
184-
inline bool isAccRoutineOp(mlir::Operation *op) {
183+
static constexpr StringLiteral getSpecializedRoutineAttrName() {
184+
return SpecializedRoutineAttr::name;
185+
}
186+
187+
/// Used to check whether the current operation is marked with
188+
/// `acc routine`. The operation passed in should be a function.
189+
inline bool isAccRoutine(mlir::Operation *op) {
185190
return op->hasAttr(mlir::acc::getRoutineInfoAttrName());
186191
}
187192

193+
/// Used to check whether this is a specialized accelerator version of
194+
/// `acc routine` function.
195+
inline bool isSpecializedAccRoutine(mlir::Operation *op) {
196+
return op->hasAttr(mlir::acc::getSpecializedRoutineAttrName());
197+
}
198+
188199
static constexpr StringLiteral getFromDefaultClauseAttrName() {
189200
return StringLiteral("acc.from_default");
190201
}

mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,32 @@ def OpenACC_LoopParMode : I32EnumAttr<
152152
let genSpecializedAttr = 0;
153153
}
154154

155+
// Parallelism level (gang/worker/vector/seq).
156+
// GangDim1 is the default gang level (equivalent to just "gang").
157+
// GangDim2/GangDim3 are for gang(dim:2) and gang(dim:3).
158+
def OpenACC_ParLevelSeq : I32EnumAttrCase<"seq", 0>;
159+
def OpenACC_ParLevelGangDim1 : I32EnumAttrCase<"gang_dim1", 1>;
160+
def OpenACC_ParLevelGangDim2 : I32EnumAttrCase<"gang_dim2", 2>;
161+
def OpenACC_ParLevelGangDim3 : I32EnumAttrCase<"gang_dim3", 3>;
162+
def OpenACC_ParLevelWorker : I32EnumAttrCase<"worker", 4>;
163+
def OpenACC_ParLevelVector : I32EnumAttrCase<"vector", 5>;
164+
165+
def OpenACC_ParLevel : I32EnumAttr<"ParLevel",
166+
"Parallelism level (gang/worker/vector/seq)",
167+
[OpenACC_ParLevelSeq,
168+
OpenACC_ParLevelGangDim1, OpenACC_ParLevelGangDim2,
169+
OpenACC_ParLevelGangDim3,
170+
OpenACC_ParLevelWorker, OpenACC_ParLevelVector]> {
171+
let genSpecializedAttr = 0;
172+
let cppNamespace = "::mlir::acc";
173+
}
174+
175+
def OpenACC_ParLevelAttr : EnumAttr<OpenACC_Dialect,
176+
OpenACC_ParLevel,
177+
"par_level"> {
178+
let assemblyFormat = [{ ```<` $value `>` }];
179+
}
180+
155181
def OpenACC_PrivateRecipe : I32EnumAttrCase<"private_recipe", 0>;
156182
def OpenACC_FirstprivateRecipe : I32EnumAttrCase<"firstprivate_recipe", 1>;
157183
def OpenACC_ReductionRecipe : I32EnumAttrCase<"reduction_recipe", 2>;
@@ -3349,6 +3375,58 @@ def RoutineInfoAttr : OpenACC_Attr<"RoutineInfo", "routine_info"> {
33493375
let assemblyFormat = "`<` `[` `` $accRoutines `]` `>`";
33503376
}
33513377

3378+
def SpecializedRoutineAttr : OpenACC_Attr<"SpecializedRoutine",
3379+
"specialized_routine"> {
3380+
let summary = "Marks a specialized device version of an acc routine";
3381+
3382+
let description = [{
3383+
This attribute is attached to a function that was specialized from a host
3384+
function marked with `acc.routine_info`. It captures the parallelism level,
3385+
a reference to the original `acc.routine` operation, and the original
3386+
function name (since the specialized function may be renamed).
3387+
3388+
Example - before specialization:
3389+
```mlir
3390+
acc.routine @routine_gang func(@foo) gang
3391+
acc.routine @routine_vector func(@foo) vector
3392+
3393+
func.func @foo() attributes {
3394+
acc.routine_info = #acc.routine_info<[@routine_gang, @routine_vector]>
3395+
} { ... }
3396+
```
3397+
3398+
After specialization, there are three functions: the original function and
3399+
two specialized versions (one per parallelism level):
3400+
```mlir
3401+
acc.routine @routine_gang func(@foo) gang
3402+
acc.routine @routine_vector func(@foo) vector
3403+
3404+
// Original function (unchanged)
3405+
func.func @foo() attributes {
3406+
acc.routine_info = #acc.routine_info<[@routine_gang, @routine_vector]>
3407+
} { ... }
3408+
3409+
// Specialized for gang parallelism
3410+
func.func @foo_gang() attributes {
3411+
acc.specialized_routine = #acc.specialized_routine<@routine_gang, <gang_dim1>, "foo">
3412+
} { ... }
3413+
3414+
// Specialized for vector parallelism
3415+
func.func @foo_vector() attributes {
3416+
acc.specialized_routine = #acc.specialized_routine<@routine_vector, <vector>, "foo">
3417+
} { ... }
3418+
```
3419+
}];
3420+
3421+
let parameters = (ins
3422+
"SymbolRefAttr":$routine,
3423+
"ParLevelAttr":$level,
3424+
"StringAttr":$funcName
3425+
);
3426+
3427+
let assemblyFormat = "`<` $routine `,` $level `,` $funcName `>`";
3428+
}
3429+
33523430
//===----------------------------------------------------------------------===//
33533431
// 2.14.1. Init Directive
33543432
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/OpenACC/Transforms/ACCImplicitDeclare.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,9 @@ class ACCImplicitDeclare
360360
accOp.getRegion(), globalsToAccDeclare, accSupport, symTab);
361361
})
362362
.Case<FunctionOpInterface>([&](auto func) {
363-
if (acc::isAccRoutineOp(func) && !func.isExternal())
363+
if ((acc::isAccRoutine(func) ||
364+
acc::isSpecializedAccRoutine(func)) &&
365+
!func.isExternal())
364366
collectGlobalsFromDeviceRegion(func.getFunctionBody(),
365367
globalsToAccDeclare, accSupport,
366368
symTab);

mlir/test/Dialect/OpenACC/ops.mlir

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1810,6 +1810,59 @@ acc.routine @acc_func_rout9 func(@acc_func) bind("acc_func_gpu_gang_dim1") gang(
18101810

18111811
// -----
18121812

1813+
// Test acc.specialized_routine attribute for specialized device functions
1814+
acc.routine @routine_seq func(@device_func_seq) seq
1815+
acc.routine @routine_gang func(@device_func_gang) gang
1816+
acc.routine @routine_gang_dim2 func(@device_func_gang_dim2) gang(dim: 2 : i64)
1817+
acc.routine @routine_gang_dim3 func(@device_func_gang_dim3) gang(dim: 3 : i64)
1818+
acc.routine @routine_worker func(@device_func_worker) worker
1819+
acc.routine @routine_vector func(@device_func_vector) vector
1820+
1821+
func.func @device_func_seq() attributes {acc.specialized_routine = #acc.specialized_routine<@routine_seq, <seq>, "host_func_seq">} {
1822+
return
1823+
}
1824+
1825+
func.func @device_func_gang() attributes {acc.specialized_routine = #acc.specialized_routine<@routine_gang, <gang_dim1>, "host_func_gang">} {
1826+
return
1827+
}
1828+
1829+
func.func @device_func_gang_dim2() attributes {acc.specialized_routine = #acc.specialized_routine<@routine_gang_dim2, <gang_dim2>, "host_func_gang_dim2">} {
1830+
return
1831+
}
1832+
1833+
func.func @device_func_gang_dim3() attributes {acc.specialized_routine = #acc.specialized_routine<@routine_gang_dim3, <gang_dim3>, "host_func_gang_dim3">} {
1834+
return
1835+
}
1836+
1837+
func.func @device_func_worker() attributes {acc.specialized_routine = #acc.specialized_routine<@routine_worker, <worker>, "host_func_worker">} {
1838+
return
1839+
}
1840+
1841+
func.func @device_func_vector() attributes {acc.specialized_routine = #acc.specialized_routine<@routine_vector, <vector>, "host_func_vector">} {
1842+
return
1843+
}
1844+
1845+
// CHECK: acc.routine @routine_seq func(@device_func_seq) seq
1846+
// CHECK: acc.routine @routine_gang func(@device_func_gang) gang
1847+
// CHECK: acc.routine @routine_gang_dim2 func(@device_func_gang_dim2) gang(dim: 2 : i64)
1848+
// CHECK: acc.routine @routine_gang_dim3 func(@device_func_gang_dim3) gang(dim: 3 : i64)
1849+
// CHECK: acc.routine @routine_worker func(@device_func_worker) worker
1850+
// CHECK: acc.routine @routine_vector func(@device_func_vector) vector
1851+
// CHECK-LABEL: func.func @device_func_seq()
1852+
// CHECK: attributes {acc.specialized_routine = #acc.specialized_routine<@routine_seq, <seq>, "host_func_seq">}
1853+
// CHECK-LABEL: func.func @device_func_gang()
1854+
// CHECK: attributes {acc.specialized_routine = #acc.specialized_routine<@routine_gang, <gang_dim1>, "host_func_gang">}
1855+
// CHECK-LABEL: func.func @device_func_gang_dim2()
1856+
// CHECK: attributes {acc.specialized_routine = #acc.specialized_routine<@routine_gang_dim2, <gang_dim2>, "host_func_gang_dim2">}
1857+
// CHECK-LABEL: func.func @device_func_gang_dim3()
1858+
// CHECK: attributes {acc.specialized_routine = #acc.specialized_routine<@routine_gang_dim3, <gang_dim3>, "host_func_gang_dim3">}
1859+
// CHECK-LABEL: func.func @device_func_worker()
1860+
// CHECK: attributes {acc.specialized_routine = #acc.specialized_routine<@routine_worker, <worker>, "host_func_worker">}
1861+
// CHECK-LABEL: func.func @device_func_vector()
1862+
// CHECK: attributes {acc.specialized_routine = #acc.specialized_routine<@routine_vector, <vector>, "host_func_vector">}
1863+
1864+
// -----
1865+
18131866
func.func @acc_func() -> () {
18141867
"test.openacc_dummy_op"() {acc.declare_action = #acc.declare_action<postAlloc = @_QMacc_declareFacc_declare_allocateEa_acc_declare_update_desc_post_alloc>} : () -> ()
18151868
return

0 commit comments

Comments
 (0)