Skip to content

Commit 3bcfb0f

Browse files
committed
Changedoconcurrent inside cuf kernel directive induction variable type from i32 to index to be the same as regular do loops inside cuf kernel.
1 parent 2bbe30b commit 3bcfb0f

File tree

2 files changed

+12
-17
lines changed

2 files changed

+12
-17
lines changed

flang/lib/Lower/Bridge.cpp

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3150,10 +3150,9 @@ class FirConverter : public Fortran::lower::AbstractConverter {
31503150
loc, 1); // Use index type directly
31513151

31523152
// Ensure lb, ub, and step are of index type using fir.convert
3153-
mlir::Type indexType = builder->getIndexType();
3154-
lb = builder->create<fir::ConvertOp>(loc, indexType, lb);
3155-
ub = builder->create<fir::ConvertOp>(loc, indexType, ub);
3156-
step = builder->create<fir::ConvertOp>(loc, indexType, step);
3153+
lb = builder->create<fir::ConvertOp>(loc, idxTy, lb);
3154+
ub = builder->create<fir::ConvertOp>(loc, idxTy, ub);
3155+
step = builder->create<fir::ConvertOp>(loc, idxTy, step);
31573156

31583157
lbs.push_back(lb);
31593158
ubs.push_back(ub);
@@ -3163,18 +3162,14 @@ class FirConverter : public Fortran::lower::AbstractConverter {
31633162

31643163
// Handle induction variable
31653164
mlir::Value ivValue = getSymbolAddress(*name.symbol);
3166-
std::size_t ivTypeSize = name.symbol->size();
3167-
if (ivTypeSize == 0)
3168-
llvm::report_fatal_error("unexpected induction variable size");
3169-
mlir::Type ivTy = builder->getIntegerType(ivTypeSize * 8);
31703165

31713166
if (!ivValue) {
31723167
// DO CONCURRENT induction variables are not mapped yet since they are
31733168
// local to the DO CONCURRENT scope.
31743169
mlir::OpBuilder::InsertPoint insPt = builder->saveInsertionPoint();
31753170
builder->setInsertionPointToStart(builder->getAllocaBlock());
31763171
ivValue = builder->createTemporaryAlloc(
3177-
loc, ivTy, toStringRef(name.symbol->name()));
3172+
loc, idxTy, toStringRef(name.symbol->name()));
31783173
builder->restoreInsertionPoint(insPt);
31793174
}
31803175

@@ -3186,7 +3181,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
31863181
// Bind the symbol to the declared variable
31873182
bindSymbol(*name.symbol, ivValue);
31883183
ivValues.push_back(ivValue);
3189-
ivTypes.push_back(ivTy);
3184+
ivTypes.push_back(idxTy);
31903185
ivLocs.push_back(loc);
31913186
}
31923187
} else {

flang/test/Lower/CUDA/cuda-doconc.cuf

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ subroutine doconc1
1515
end
1616

1717
! CHECK: func.func @_QPdoconc1() {
18-
! CHECK: %[[DECL:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc1Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
18+
! CHECK: %[[DECL:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc1Ei"} : (!fir.ref<index>) -> (!fir.ref<index>, !fir.ref<index>)
1919
! CHECK: cuf.kernel<<<*, *>>>
20-
! CHECK: %{{.*}} = fir.load %[[DECL]]#0 : !fir.ref<i32>
20+
! CHECK: %{{.*}} = fir.load %[[DECL]]#0 : !fir.ref<index>
2121

2222
subroutine doconc2
2323
integer :: i, j, m, n
@@ -32,8 +32,8 @@ subroutine doconc2
3232
end
3333

3434
! CHECK: func.func @_QPdoconc2() {
35-
! CHECK: %[[DECLI:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc2Ei"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
36-
! CHECK: %[[DECLJ:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc2Ej"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
37-
! CHECK: cuf.kernel<<<*, *>>> (%arg0 : i32, %arg1 : i32) = (%{{.*}}, %{{.*}} : index, index) to (%{{.*}}, %{{.*}} : index, index) step (%{{.*}}, %{{.*}} : index, index) {
38-
! CHECK: %{{.*}} = fir.load %[[DECLI]]#0 : !fir.ref<i32>
39-
! CHECK: %{{.*}} = fir.load %[[DECLJ]]#0 : !fir.ref<i32>
35+
! CHECK: %[[DECLI:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc2Ei"} : (!fir.ref<index>) -> (!fir.ref<index>, !fir.ref<index>)
36+
! CHECK: %[[DECLJ:.*]]:2 = hlfir.declare %{{.*}}#0 {uniq_name = "_QFdoconc2Ej"} : (!fir.ref<index>) -> (!fir.ref<index>, !fir.ref<index>)
37+
! CHECK: cuf.kernel<<<*, *>>> (%arg0 : index, %arg1 : index) = (%{{.*}}, %{{.*}} : index, index) to (%{{.*}}, %{{.*}} : index, index) step (%{{.*}}, %{{.*}} : index, index) {
38+
! CHECK: %{{.*}} = fir.load %[[DECLI]]#0 : !fir.ref<index>
39+
! CHECK: %{{.*}} = fir.load %[[DECLJ]]#0 : !fir.ref<index>

0 commit comments

Comments
 (0)