1313
1414layout (std430) buffer ;
1515
16- layout (set = 0 , binding = 0 , ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;
17- layout (set = 0 , binding = 1 ) uniform PRECISION sampler3D image_in;
18- layout (set = 0 , binding = 2 ) uniform PRECISION sampler3D weight_in;
19- layout (set = 0 , binding = 3 ) uniform PRECISION sampler3D bias_in;
20- layout (set = 0 , binding = 4 ) uniform PRECISION sampler3D mean_in;
21- layout (set = 0 , binding = 5 ) uniform PRECISION sampler3D var_in;
16+ #include "indexing_utils.h"
2217
23- layout (set = 0 , binding = 6 ) uniform PRECISION restrict OutLimits {
24- ivec3 out_limits;
25- };
18+ ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
19+ ${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
20+ ${layout_declare_tensor(B, "r", "weight_in", DTYPE, STORAGE)}
21+ ${layout_declare_tensor(B, "r", "bias_in", DTYPE, STORAGE)}
22+ ${layout_declare_tensor(B, "r", "mean_in", DTYPE, STORAGE)}
23+ ${layout_declare_tensor(B, "r", "var_in", DTYPE, STORAGE)}
2624
27- layout (set = 0 , binding = 7 ) uniform PRECISION restrict Params {
28- float eps;
29- };
30-
31- layout (set = 0 , binding = 8 ) uniform PRECISION restrict Params2 {
32- int num_texel_per_batch;
33- };
25+ ${layout_declare_ubo(B, "ivec3 ", "out_limits")}
26+ ${layout_declare_ubo(B, "float ", "eps")}
27+ ${layout_declare_ubo(B, "int ", "num_texel_per_batch")}
3428
3529layout (local_size_x_id = 0 , local_size_y_id = 1 , local_size_z_id = 2 ) in ;
3630
@@ -40,16 +34,16 @@ void main() {
4034 return ;
4135 }
4236
43- VEC4_T v = VEC4_T(texelFetch(image_in , pos, 0 ));
37+ VEC4_T v = VEC4_T(load_texel(t_in , pos));
4438
4539 ivec3 param_pos = ivec3 (pos.z % num_texel_per_batch, 0 , 0 );
4640
47- VEC4_T weight = VEC4_T(texelFetch (weight_in, param_pos, 0 ));
48- VEC4_T bias = VEC4_T(texelFetch (bias_in, param_pos, 0 ));
49- VEC4_T mean = VEC4_T(texelFetch (mean_in, param_pos, 0 ));
50- VEC4_T var = VEC4_T(texelFetch (var_in, param_pos, 0 ));
41+ VEC4_T weight = VEC4_T(load_texel (weight_in, param_pos));
42+ VEC4_T bias = VEC4_T(load_texel (bias_in, param_pos));
43+ VEC4_T mean = VEC4_T(load_texel (mean_in, param_pos));
44+ VEC4_T var = VEC4_T(load_texel (var_in, param_pos));
5145
5246 v = ((v - mean) / sqrt (var + eps)) * weight + bias;
5347
54- imageStore(image_out , pos, v);
48+ write_texel(t_out , pos, v);
5549}
0 commit comments