@@ -213,15 +213,15 @@ end
213
213
us:: DataLayouts.UniversalSize ,
214
214
n_max_threads:: Integer ,
215
215
)
216
- (Nij, _ , _, _, Nh) = DataLayouts. universal_size (us)
216
+ (Ni, Nj , _, _, Nh) = DataLayouts. universal_size (us)
217
217
Nh_thread = min (
218
- Int (fld (n_max_threads, Nij * Nij )),
218
+ Int (fld (n_max_threads, Ni * Nj )),
219
219
maximum_allowable_threads ()[3 ],
220
220
Nh,
221
221
)
222
222
Nh_blocks = cld (Nh, Nh_thread)
223
- @assert prod ((Nij, Nij , Nh_thread)) ≤ n_max_threads " threads,n_max_threads=($(prod ((Nij, Nij , Nh_thread))) ,$n_max_threads )"
224
- return (; threads = (Nij, Nij , Nh_thread), blocks = (Nh_blocks,))
223
+ @assert prod ((Ni, Nj , Nh_thread)) ≤ n_max_threads " threads,n_max_threads=($(prod ((Ni, Nj , Nh_thread))) ,$n_max_threads )"
224
+ return (; threads = (Ni, Nj , Nh_thread), blocks = (Nh_blocks,))
225
225
end
226
226
@inline function columnwise_universal_index (us:: UniversalSize )
227
227
(i, j, th) = CUDA. threadIdx ()
241
241
n_max_threads:: Integer ;
242
242
Nnames,
243
243
)
244
- (Nij, _ , _, _, Nh) = DataLayouts. universal_size (us)
245
- @assert prod ((Nij, Nij , Nnames)) ≤ n_max_threads " threads,n_max_threads=($(prod ((Nij, Nij , Nnames))) ,$n_max_threads )"
246
- return (; threads = (Nij, Nij , Nnames), blocks = (Nh,))
244
+ (Ni, Nj , _, _, Nh) = DataLayouts. universal_size (us)
245
+ @assert prod ((Ni, Nj , Nnames)) ≤ n_max_threads " threads,n_max_threads=($(prod ((Ni, Nj , Nnames))) ,$n_max_threads )"
246
+ return (; threads = (Ni, Nj , Nnames), blocks = (Nh,))
247
247
end
248
248
@inline function multiple_field_solve_universal_index (us:: UniversalSize )
249
249
(i, j, iname) = CUDA. threadIdx ()
@@ -258,12 +258,12 @@ end
258
258
us:: DataLayouts.UniversalSize ,
259
259
n_max_threads:: Integer = 256 ;
260
260
)
261
- (Nq, _ , _, Nv, Nh) = DataLayouts. universal_size (us)
262
- Nvthreads = min (fld (n_max_threads, Nq * Nq ), maximum_allowable_threads ()[3 ])
261
+ (Ni, Nj , _, Nv, Nh) = DataLayouts. universal_size (us)
262
+ Nvthreads = min (fld (n_max_threads, Ni * Nj ), maximum_allowable_threads ()[3 ])
263
263
Nvblocks = cld (Nv, Nvthreads)
264
- @assert prod ((Nq, Nq , Nvthreads)) ≤ n_max_threads " threads,n_max_threads=($(prod ((Nq, Nq , Nvthreads))) ,$n_max_threads )"
265
- @assert Nq * Nq ≤ n_max_threads
266
- return (; threads = (Nq, Nq , Nvthreads), blocks = (Nh, Nvblocks), Nvthreads)
264
+ @assert prod ((Ni, Nj , Nvthreads)) ≤ n_max_threads " threads,n_max_threads=($(prod ((Ni, Nj , Nvthreads))) ,$n_max_threads )"
265
+ @assert Ni * Nj ≤ n_max_threads
266
+ return (; threads = (Ni, Nj , Nvthreads), blocks = (Nh, Nvblocks), Nvthreads)
267
267
end
268
268
@inline function spectral_universal_index (space:: Spaces.AbstractSpace )
269
269
i = threadIdx (). x
0 commit comments