Skip to content

Commit 825eefe

Browse files
authored
[flang][cuda] Accept scalar expression for bytes in kernel call (#165040)
1 parent 8c29bce commit 825eefe

File tree

5 files changed: 28 additions and 5 deletions.

flang/include/flang/Parser/parse-tree.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3274,13 +3274,13 @@ struct FunctionReference {
32743274
// R1521 call-stmt -> CALL procedure-designator [ chevrons ]
32753275
// [( [actual-arg-spec-list] )]
32763276
// (CUDA) chevrons -> <<< * | scalar-expr, scalar-expr [,
3277-
// scalar-int-expr [, scalar-int-expr ] ] >>>
3277+
// scalar-expr [, scalar-int-expr ] ] >>>
32783278
struct CallStmt {
32793279
BOILERPLATE(CallStmt);
32803280
WRAPPER_CLASS(StarOrExpr, std::optional<ScalarExpr>);
32813281
struct Chevrons {
32823282
TUPLE_CLASS_BOILERPLATE(Chevrons);
3283-
std::tuple<StarOrExpr, ScalarExpr, std::optional<ScalarIntExpr>,
3283+
std::tuple<StarOrExpr, ScalarExpr, std::optional<ScalarExpr>,
32843284
std::optional<ScalarIntExpr>>
32853285
t;
32863286
};

flang/lib/Parser/program-parsers.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -484,7 +484,7 @@ constexpr auto starOrExpr{
484484
applyFunction(presentOptional<ScalarExpr>, scalarExpr))};
485485
TYPE_PARSER(extension<LanguageFeature::CUDA>(
486486
"<<<" >> construct<CallStmt::Chevrons>(starOrExpr, ", " >> scalarExpr,
487-
maybe("," >> scalarIntExpr), maybe("," >> scalarIntExpr)) /
487+
maybe("," >> scalarExpr), maybe("," >> scalarIntExpr)) /
488488
">>>"))
489489
constexpr auto actualArgSpecList{optionalList(actualArgSpec)};
490490
TYPE_CONTEXT_PARSER("CALL statement"_en_US,

flang/test/Lower/CUDA/cuda-kernel-calls.cuf

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@ contains
1616
subroutine host()
1717
real, device :: a
1818
integer(8) :: stream
19+
integer(4) :: nbytes
1920

2021
! CHECK-LABEL: func.func @_QMtest_callPhost()
2122
! CHECK: %[[A:.*]]:2 = hlfir.declare %{{.*}} {data_attr = #cuf.cuda<device>, uniq_name = "_QMtest_callFhostEa"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
@@ -57,6 +58,10 @@ contains
5758
call dev_kernel1<<<*,32,0,stream>>>(a)
5859
! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel1<<<%c-1{{.*}}, %c1{{.*}}, %c1{{.*}}, %c32{{.*}}, %c1{{.*}}, %c1{{.*}}, %c0{{.*}}, %{{.*}} : !fir.ref<i64>>>>(%{{.*}}) : (!fir.ref<f32>)
5960

61+
call dev_kernel1<<<*, 32, 0.8 * nbytes>>>(a)
62+
! CHECK: %[[MUL:.*]] = arith.mulf %{{.*}}, %{{.*}} fastmath<contract> : f32
63+
! CHECK: %[[BYTES:.*]] = fir.convert %[[MUL]] : (f32) -> i32
64+
! CHECK: cuf.kernel_launch @_QMtest_callPdev_kernel1<<<%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[BYTES]]>>>(%{{.*}}) : (!fir.ref<f32>)
6065
end
6166

6267
end

flang/test/Parser/cuf-sanity-common

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -43,6 +43,7 @@ module m
4343
call globalsub<<<1, 2>>>
4444
call globalsub<<<1, 2, 3>>>
4545
call globalsub<<<1, 2, 3, 4>>>
46+
call globalsub<<<1, 2, 0.9*10, 4>>>
4647
call globalsub<<<*,5>>>
4748
allocate(pa(32), pinned = isPinned)
4849
end subroutine

flang/test/Parser/cuf-sanity-tree.CUF

Lines changed: 19 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -178,7 +178,7 @@ include "cuf-sanity-common"
178178
!CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '1'
179179
!CHECK: | | | | | | Scalar -> Expr = '2_4'
180180
!CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '2'
181-
!CHECK: | | | | | | Scalar -> Integer -> Expr = '3_4'
181+
!CHECK: | | | | | | Scalar -> Expr = '3_4'
182182
!CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '3'
183183
!CHECK: | | | | ExecutionPartConstruct -> ExecutableConstruct -> ActionStmt -> CallStmt = 'CALL globalsub<<<1_4,2_4,3_4,4_4>>>()'
184184
!CHECK: | | | | | Call
@@ -188,10 +188,27 @@ include "cuf-sanity-common"
188188
!CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '1'
189189
!CHECK: | | | | | | Scalar -> Expr = '2_4'
190190
!CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '2'
191-
!CHECK: | | | | | | Scalar -> Integer -> Expr = '3_4'
191+
!CHECK: | | | | | | Scalar -> Expr = '3_4'
192192
!CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '3'
193193
!CHECK: | | | | | | Scalar -> Integer -> Expr = '4_4'
194194
!CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '4'
195+
!CHECK: | | | | ExecutionPartConstruct -> ExecutableConstruct -> ActionStmt -> CallStmt = 'CALL globalsub<<<1_4,2_4,9._4,4_4>>>()'
196+
!CHECK: | | | | | Call
197+
!CHECK: | | | | | | ProcedureDesignator -> Name = 'globalsub'
198+
!CHECK: | | | | | Chevrons
199+
!CHECK: | | | | | | StarOrExpr -> Scalar -> Expr = '1_4'
200+
!CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '1'
201+
!CHECK: | | | | | | Scalar -> Expr = '2_4'
202+
!CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '2'
203+
!CHECK: | | | | | | Scalar -> Expr = '9._4'
204+
!CHECK: | | | | | | | Multiply
205+
!CHECK: | | | | | | | | Expr = '8.9999997615814208984375e-1_4'
206+
!CHECK: | | | | | | | | | LiteralConstant -> RealLiteralConstant
207+
!CHECK: | | | | | | | | | | Real = '0.9'
208+
!CHECK: | | | | | | | | Expr = '10_4'
209+
!CHECK: | | | | | | | | | LiteralConstant -> IntLiteralConstant = '10'
210+
!CHECK: | | | | | | Scalar -> Integer -> Expr = '4_4'
211+
!CHECK: | | | | | | | LiteralConstant -> IntLiteralConstant = '4'
195212
!CHECK: | | | | ExecutionPartConstruct -> ExecutableConstruct -> ActionStmt -> AllocateStmt
196213
!CHECK: | | | | | Allocation
197214
!CHECK: | | | | | | AllocateObject = 'pa'

0 commit comments

Comments (0)