Skip to content

Commit 6ca9a30

Browse files
authored
[flang][cuda] Update stream operand type for cuf.kernel_launch op (llvm#135222)
1 parent 6493345 commit 6ca9a30

File tree

3 files changed

+15
-19
lines changed

3 files changed

+15
-19
lines changed

flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td

Lines changed: 7 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -197,24 +197,16 @@ def cuf_KernelLaunchOp : cuf_Op<"kernel_launch", [CallOpInterface,
197197
```
198198
}];
199199

200-
let arguments = (ins
201-
SymbolRefAttr:$callee,
202-
I32:$grid_x,
203-
I32:$grid_y,
204-
I32:$grid_z,
205-
I32:$block_x,
206-
I32:$block_y,
207-
I32:$block_z,
208-
Optional<I32>:$bytes,
209-
Optional<I32>:$stream,
210-
Variadic<AnyType>:$args,
211-
OptionalAttr<DictArrayAttr>:$arg_attrs,
212-
OptionalAttr<DictArrayAttr>:$res_attrs
213-
);
200+
let arguments = (ins SymbolRefAttr:$callee, I32:$grid_x, I32:$grid_y,
201+
I32:$grid_z, I32:$block_x, I32:$block_y, I32:$block_z,
202+
Optional<I32>:$bytes, Optional<AnyIntegerType>:$stream,
203+
Variadic<AnyType>:$args, OptionalAttr<DictArrayAttr>:$arg_attrs,
204+
OptionalAttr<DictArrayAttr>:$res_attrs);
214205

215206
let assemblyFormat = [{
216207
$callee `<` `<` `<` $grid_x `,` $grid_y `,` $grid_z `,`$block_x `,`
217-
$block_y `,` $block_z ( `,` $bytes^ ( `,` $stream^ )? )? `>` `>` `>`
208+
$block_y `,` $block_z
209+
( `,` $bytes^ ( `,` $stream^ `:` type($stream) )? )? `>` `>` `>`
218210
`` `(` $args `)` ( `:` `(` type($args)^ `)` )? attr-dict
219211
}];
220212

flang/lib/Lower/ConvertCall.cpp

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -588,10 +588,8 @@ Fortran::lower::genCallOpAndResult(
588588

589589
mlir::Value stream; // stream is optional.
590590
if (caller.getCallDescription().chevrons().size() > 3)
591-
stream = builder.createConvert(
592-
loc, i32Ty,
593-
fir::getBase(converter.genExprValue(
594-
caller.getCallDescription().chevrons()[3], stmtCtx)));
591+
stream = fir::getBase(converter.genExprValue(
592+
caller.getCallDescription().chevrons()[3], stmtCtx));
595593

596594
builder.create<cuf::KernelLaunchOp>(
597595
loc, funcType.getResults(), funcSymbolAttr, grid_x, grid_y, grid_z,

flang/test/Lower/CUDA/cuda-kernel-calls.cuf

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,8 @@ contains
1515

1616
subroutine host()
1717
real, device :: a
18+
integer(8) :: stream
19+
1820
! CHECK-LABEL: func.func @_QMtest_callPhost()
1921
! CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QMtest_callFhostEa"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
2022

@@ -51,6 +53,10 @@ contains
5153

5254
call dev_kernel1<<<*, 32>>>(a)
5355
! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel1<<<%c-1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}>>>(%{{.*}})
56+
57+
call dev_kernel1<<<*,32,0,stream>>>(a)
58+
! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel1<<<%c-1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}, %c0{{.*}}, %{{.*}} : i64>>>(%{{.*}}) : (!fir.ref<f32>)
59+
5460
end
5561

5662
end

0 commit comments

Comments
 (0)