Skip to content

Commit fc5ad5f

Browse files
committed
Compute up to one missing dim
1 parent 3fbd025 commit fc5ad5f

File tree

1 file changed

+66
-14
lines changed

1 file changed

+66
-14
lines changed

flang/runtime/CUDA/kernel.cpp

Lines changed: 66 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,17 @@ void RTDEF(CUFLaunchKernel)(const void *kernel, intptr_t gridX, intptr_t gridY,
2525
blockDim.x = blockX;
2626
blockDim.y = blockY;
2727
blockDim.z = blockZ;
28-
bool gridIsStar = (gridX < 0); // <<<*, block>>> syntax was used.
29-
if (gridIsStar) {
28+
unsigned nbNegGridDim{0};
29+
if (gridX < 0) {
30+
++nbNegGridDim;
31+
}
32+
if (gridY < 0) {
33+
++nbNegGridDim;
34+
}
35+
if (gridZ < 0) {
36+
++nbNegGridDim;
37+
}
38+
if (nbNegGridDim == 1) {
3039
int maxBlocks, nbBlocks, dev, multiProcCount;
3140
cudaError_t err1, err2;
3241
nbBlocks = blockDim.x * blockDim.y * blockDim.z;
@@ -35,18 +44,35 @@ void RTDEF(CUFLaunchKernel)(const void *kernel, intptr_t gridX, intptr_t gridY,
3544
&multiProcCount, cudaDevAttrMultiProcessorCount, dev);
3645
err2 = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
3746
&maxBlocks, kernel, nbBlocks, smem);
38-
if (err1 == cudaSuccess && err2 == cudaSuccess)
47+
if (err1 == cudaSuccess && err2 == cudaSuccess) {
3948
maxBlocks = multiProcCount * maxBlocks;
49+
}
4050
if (maxBlocks > 0) {
41-
if (gridDim.y > 0)
51+
if (gridDim.x > 0) {
52+
maxBlocks = maxBlocks / gridDim.x;
53+
}
54+
if (gridDim.y > 0) {
4255
maxBlocks = maxBlocks / gridDim.y;
43-
if (gridDim.z > 0)
56+
}
57+
if (gridDim.z > 0) {
4458
maxBlocks = maxBlocks / gridDim.z;
45-
if (maxBlocks < 1)
59+
}
60+
if (maxBlocks < 1) {
4661
maxBlocks = 1;
47-
if (gridIsStar)
62+
}
63+
if (gridX < 0) {
4864
gridDim.x = maxBlocks;
65+
}
66+
if (gridY < 0) {
67+
gridDim.y = maxBlocks;
68+
}
69+
if (gridZ < 0) {
70+
gridDim.z = maxBlocks;
71+
}
4972
}
73+
} else if (nbNegGridDim > 1) {
74+
Fortran::runtime::Terminator terminator{__FILE__, __LINE__};
75+
terminator.Crash("Too many invalid grid dimensions");
5076
}
5177
cudaStream_t stream = 0; // TODO stream managment
5278
CUDA_REPORT_IF_ERROR(
@@ -64,8 +90,17 @@ void RTDEF(CUFLaunchClusterKernel)(const void *kernel, intptr_t clusterX,
6490
config.blockDim.x = blockX;
6591
config.blockDim.y = blockY;
6692
config.blockDim.z = blockZ;
67-
bool gridIsStar = (gridX < 0); // <<<*, block>>> syntax was used.
68-
if (gridIsStar) {
93+
unsigned nbNegGridDim{0};
94+
if (gridX < 0) {
95+
++nbNegGridDim;
96+
}
97+
if (gridY < 0) {
98+
++nbNegGridDim;
99+
}
100+
if (gridZ < 0) {
101+
++nbNegGridDim;
102+
}
103+
if (nbNegGridDim == 1) {
69104
int maxBlocks, nbBlocks, dev, multiProcCount;
70105
cudaError_t err1, err2;
71106
nbBlocks = config.blockDim.x * config.blockDim.y * config.blockDim.z;
@@ -74,18 +109,35 @@ void RTDEF(CUFLaunchClusterKernel)(const void *kernel, intptr_t clusterX,
74109
&multiProcCount, cudaDevAttrMultiProcessorCount, dev);
75110
err2 = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
76111
&maxBlocks, kernel, nbBlocks, smem);
77-
if (err1 == cudaSuccess && err2 == cudaSuccess)
112+
if (err1 == cudaSuccess && err2 == cudaSuccess) {
78113
maxBlocks = multiProcCount * maxBlocks;
114+
}
79115
if (maxBlocks > 0) {
80-
if (config.gridDim.y > 0)
116+
if (config.gridDim.x > 0) {
117+
maxBlocks = maxBlocks / config.gridDim.x;
118+
}
119+
if (config.gridDim.y > 0) {
81120
maxBlocks = maxBlocks / config.gridDim.y;
82-
if (config.gridDim.z > 0)
121+
}
122+
if (config.gridDim.z > 0) {
83123
maxBlocks = maxBlocks / config.gridDim.z;
84-
if (maxBlocks < 1)
124+
}
125+
if (maxBlocks < 1) {
85126
maxBlocks = 1;
86-
if (gridIsStar)
127+
}
128+
if (gridX < 0) {
87129
config.gridDim.x = maxBlocks;
130+
}
131+
if (gridY < 0) {
132+
config.gridDim.y = maxBlocks;
133+
}
134+
if (gridZ < 0) {
135+
config.gridDim.z = maxBlocks;
136+
}
88137
}
138+
} else if (nbNegGridDim > 1) {
139+
Fortran::runtime::Terminator terminator{__FILE__, __LINE__};
140+
terminator.Crash("Too many invalid grid dimensions");
89141
}
90142
config.dynamicSmemBytes = smem;
91143
config.stream = 0; // TODO stream managment

0 commit comments

Comments
 (0)