@@ -4,6 +4,161 @@ import CUDA
44using CUDA: @cuda
55import ClimaCore. Remapping: _set_interpolated_values_device!
66
7+ # Bilinear in (s,t) ∈ [0,1]² (local to 2-point cell; reference element is [-1,1]²):
8+ # (1-s)(1-t)*c11 + s*(1-t)*c21 + (1-s)*t*c12 + s*t*c22
9+ @inline _bilinear (c11, c21, c22, c12, s, t) =
10+ (1 - s) * (1 - t) * c11 + s * (1 - t) * c21 + (1 - s) * t * c12 + s * t * c22
11+
12+ # Bilinear path on GPU: pure GPU kernels (no scalar indexing).
13+ # 2-point cell (i..i+1, j..j+1) containing (ξ1, ξ2); interpolate between nodes within the element.
14+ function ClimaCore. Remapping. _set_interpolated_values_bilinear! (
15+ out:: CUDA.CuArray ,
16+ fields:: AbstractArray{<:Fields.Field} ,
17+ scratch_corners,
18+ local_horiz_indices,
19+ vert_interpolation_weights:: AbstractArray ,
20+ vert_bounding_indices:: AbstractArray ,
21+ local_bilinear_s,
22+ local_bilinear_t,
23+ local_bilinear_i,
24+ local_bilinear_j,
25+ )
26+ field_values = tuple (map (f -> Fields. field_values (f), fields)... )
27+ num_horiz = length (local_horiz_indices)
28+ num_vert = length (vert_bounding_indices)
29+ num_fields = length (field_values)
30+ nitems = length (out)
31+ args = (
32+ out,
33+ local_horiz_indices,
34+ local_bilinear_s,
35+ local_bilinear_t,
36+ local_bilinear_i,
37+ local_bilinear_j,
38+ vert_interpolation_weights,
39+ vert_bounding_indices,
40+ field_values,
41+ )
42+ threads = threads_via_occupancy (set_interpolated_values_bilinear_3d_kernel!, args)
43+ p = linear_partition (nitems, threads)
44+ auto_launch! (
45+ set_interpolated_values_bilinear_3d_kernel!,
46+ args;
47+ threads_s = (p. threads,),
48+ blocks_s = (p. blocks,),
49+ )
50+ end
51+
52+ function set_interpolated_values_bilinear_3d_kernel! (
53+ out,
54+ local_horiz_indices,
55+ local_bilinear_s,
56+ local_bilinear_t,
57+ local_bilinear_i,
58+ local_bilinear_j,
59+ vert_interpolation_weights,
60+ vert_bounding_indices,
61+ field_values,
62+ )
63+ num_horiz = length (local_horiz_indices)
64+ num_vert = length (vert_bounding_indices)
65+ num_fields = length (field_values)
66+ inds = (num_horiz, num_vert, num_fields)
67+ i_thread =
68+ (CUDA. blockIdx (). x - Int32 (1 )) * CUDA. blockDim (). x + CUDA. threadIdx (). x
69+ 1 ≤ i_thread ≤ prod (inds) || return nothing
70+ (i_out, j_v, k) = CartesianIndices (map (x -> Base. OneTo (x), inds))[i_thread]. I
71+ @inbounds begin
72+ CI = CartesianIndex
73+ h = local_horiz_indices[i_out]
74+ v_lo, v_hi = vert_bounding_indices[j_v]
75+ A, B = vert_interpolation_weights[j_v]
76+ s = local_bilinear_s[i_out]
77+ t = local_bilinear_t[i_out]
78+ ii = local_bilinear_i[i_out]
79+ jj = local_bilinear_j[i_out]
80+ fvals = field_values[k]
81+ # Four nodes of 2-point cell: (ii,jj), (ii+1,jj), (ii+1,jj+1), (ii,jj+1)
82+ c11 = A * fvals[CI (ii, jj, 1 , v_lo, h)] + B * fvals[CI (ii, jj, 1 , v_hi, h)]
83+ c21 = A * fvals[CI (ii + 1 , jj, 1 , v_lo, h)] + B * fvals[CI (ii + 1 , jj, 1 , v_hi, h)]
84+ c22 =
85+ A * fvals[CI (ii + 1 , jj + 1 , 1 , v_lo, h)] +
86+ B * fvals[CI (ii + 1 , jj + 1 , 1 , v_hi, h)]
87+ c12 = A * fvals[CI (ii, jj + 1 , 1 , v_lo, h)] + B * fvals[CI (ii, jj + 1 , 1 , v_hi, h)]
88+ out[i_out, j_v, k] = _bilinear (c11, c21, c22, c12, s, t)
89+ end
90+ return nothing
91+ end
92+
93+ function ClimaCore. Remapping. _set_interpolated_values_bilinear! (
94+ out:: CUDA.CuArray ,
95+ fields:: AbstractArray{<:Fields.Field} ,
96+ scratch_corners,
97+ local_horiz_indices,
98+ :: Nothing ,
99+ :: Nothing ,
100+ local_bilinear_s,
101+ local_bilinear_t,
102+ local_bilinear_i,
103+ local_bilinear_j,
104+ )
105+ field_values = tuple (map (f -> Fields. field_values (f), fields)... )
106+ num_horiz = length (local_horiz_indices)
107+ num_fields = length (field_values)
108+ nitems = length (out)
109+ args = (
110+ out,
111+ local_horiz_indices,
112+ local_bilinear_s,
113+ local_bilinear_t,
114+ local_bilinear_i,
115+ local_bilinear_j,
116+ field_values,
117+ )
118+ threads = threads_via_occupancy (set_interpolated_values_bilinear_2d_kernel!, args)
119+ p = linear_partition (nitems, threads)
120+ auto_launch! (
121+ set_interpolated_values_bilinear_2d_kernel!,
122+ args;
123+ threads_s = (p. threads,),
124+ blocks_s = (p. blocks,),
125+ )
126+ end
127+
128+ function set_interpolated_values_bilinear_2d_kernel! (
129+ out,
130+ local_horiz_indices,
131+ local_bilinear_s,
132+ local_bilinear_t,
133+ local_bilinear_i,
134+ local_bilinear_j,
135+ field_values,
136+ )
137+ num_horiz = length (local_horiz_indices)
138+ num_fields = length (field_values)
139+ inds = (num_horiz, num_fields)
140+ i_thread =
141+ (CUDA. blockIdx (). x - Int32 (1 )) * CUDA. blockDim (). x + CUDA. threadIdx (). x
142+ 1 ≤ i_thread ≤ prod (inds) || return nothing
143+ (i_out, k) = CartesianIndices (map (x -> Base. OneTo (x), inds))[i_thread]. I
144+ @inbounds begin
145+ CI = CartesianIndex
146+ h = local_horiz_indices[i_out]
147+ s = local_bilinear_s[i_out]
148+ t = local_bilinear_t[i_out]
149+ ii = local_bilinear_i[i_out]
150+ jj = local_bilinear_j[i_out]
151+ fvals = field_values[k]
152+ # Four nodes of 2-point cell: (ii,jj), (ii+1,jj), (ii+1,jj+1), (ii,jj+1)
153+ c11 = fvals[CI (ii, jj, 1 , 1 , h)]
154+ c21 = fvals[CI (ii + 1 , jj, 1 , 1 , h)]
155+ c22 = fvals[CI (ii + 1 , jj + 1 , 1 , 1 , h)]
156+ c12 = fvals[CI (ii, jj + 1 , 1 , 1 , h)]
157+ out[i_out, k] = _bilinear (c11, c21, c22, c12, s, t)
158+ end
159+ return nothing
160+ end
161+
7162
8163function _set_interpolated_values_device! (
9164 out:: AbstractArray ,
0 commit comments