| 
15 | 15 | 
 
  | 
16 | 16 | #define VEC4_T ${texel_type(DTYPE)}  | 
17 | 17 | 
 
  | 
 | 18 | +#define T ${texel_component_type(DTYPE)}  | 
 | 19 | + | 
18 | 20 | layout(std430) buffer;  | 
19 | 21 | 
 
  | 
20 | 22 | ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}  | 
@@ -48,37 +50,97 @@ void main() {  | 
48 | 50 | 
 
  | 
49 | 51 |   const int width = int(sizes.x);  | 
50 | 52 | 
 
  | 
51 |  | -  VEC4_T mean = VEC4_T(0);  | 
52 |  | -  VEC4_T delta = VEC4_T(0);  | 
53 |  | -  VEC4_T delta2 = VEC4_T(0);  | 
54 |  | -  VEC4_T M2 = VEC4_T(0);  | 
55 |  | - | 
56 |  | -  // Use Welford's online algorithm to compute mean and variance in one pass  | 
57 |  | -  // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm  | 
58 |  | -  ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);  | 
59 |  | -  for (int w = 0; w < width; ++w) {  | 
60 |  | -    in_pos[in_axis_map.x] = w;  | 
61 |  | -    VEC4_T v = load_texel(t_in, in_pos);  | 
62 |  | -    delta = v - mean;  | 
63 |  | -    mean += delta / (w + 1);  | 
64 |  | -    delta2 = v - mean;  | 
65 |  | -    M2 += delta * delta2;  | 
66 |  | -  }  | 
67 |  | - | 
68 |  | -  VEC4_T var = M2 / width;  | 
69 |  | -  VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));  | 
70 |  | -  VEC4_T offset = -rstd * mean;  | 
71 |  | - | 
72 |  | -  for (int w = 0; w < width; ++w) {  | 
73 |  | -    in_pos[in_axis_map.x] = w;  | 
74 |  | -    VEC4_T v = load_texel(t_in, in_pos);  | 
75 |  | -    // broadcasting  | 
76 |  | -    VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx;  | 
77 |  | -    VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx;  | 
78 |  | -    VEC4_T outtex = (v * rstd + offset) * weight + bias;  | 
79 |  | -    write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);  | 
 | 53 | +  if (in_packed_dim != W_DIM) {  | 
 | 54 | +    VEC4_T mean = VEC4_T(0);  | 
 | 55 | +    VEC4_T delta = VEC4_T(0);  | 
 | 56 | +    VEC4_T delta2 = VEC4_T(0);  | 
 | 57 | +    VEC4_T M2 = VEC4_T(0);  | 
 | 58 | + | 
 | 59 | +    // Use Welford's online algorithm to compute mean and variance in one pass  | 
 | 60 | +    // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm  | 
 | 61 | +    ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);  | 
 | 62 | +    for (int w = 0; w < width; ++w) {  | 
 | 63 | +      in_pos[in_axis_map.x] = w;  | 
 | 64 | +      VEC4_T v = load_texel(t_in, in_pos);  | 
 | 65 | +      delta = v - mean;  | 
 | 66 | +      mean += delta / (w + 1);  | 
 | 67 | +      delta2 = v - mean;  | 
 | 68 | +      M2 += delta * delta2;  | 
 | 69 | +    }  | 
 | 70 | + | 
 | 71 | +    VEC4_T var = M2 / width;  | 
 | 72 | +    VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5));  | 
 | 73 | +    VEC4_T offset = -rstd * mean;  | 
 | 74 | + | 
 | 75 | +    for (int w = 0; w < width; ++w) {  | 
 | 76 | +      in_pos[in_axis_map.x] = w;  | 
 | 77 | +      VEC4_T v = load_texel(t_in, in_pos);  | 
 | 78 | +      // broadcasting  | 
 | 79 | +      VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx;  | 
 | 80 | +      VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx;  | 
 | 81 | +      VEC4_T outtex = (v * rstd + offset) * weight + bias;  | 
 | 82 | +      write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);  | 
 | 83 | +    }  | 
 | 84 | + | 
 | 85 | +    write_texel(t_mean, lpos, mean);  | 
 | 86 | +    write_texel(t_rstd, lpos, rstd);  | 
 | 87 | +  } else {  | 
 | 88 | +    const int packed_width = divup4(width);  | 
 | 89 | + | 
 | 90 | +    T mean = T(0);  | 
 | 91 | +    T delta = T(0);  | 
 | 92 | +    T delta2 = T(0);  | 
 | 93 | +    T M2 = T(0);  | 
 | 94 | +    // Use Welford's online algorithm to compute mean and variance in one pass  | 
 | 95 | +    // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm  | 
 | 96 | +    ivec3 in_pos = lpos_to_pos(lpos, in_axis_map);  | 
 | 97 | +    T width_counter = T(1);  | 
 | 98 | + | 
 | 99 | +    const bool has_unaligned_width = (width & 0x3) != 0;  | 
 | 100 | +    const int fully_packed_4_comp_count = packed_width - mix(0, 1, has_unaligned_width);  | 
 | 101 | + | 
 | 102 | +    // iterate through texels that are fully packed ie. has 4 components  | 
 | 103 | +    for (int w = 0; w < fully_packed_4_comp_count; ++w) {  | 
 | 104 | +      in_pos[in_axis_map.x] = w;  | 
 | 105 | +      VEC4_T v = load_texel(t_in, in_pos);  | 
 | 106 | +      for (int i=0; i<4; i++) {  | 
 | 107 | +        delta = v[i] - mean;  | 
 | 108 | +        mean += delta / width_counter;  | 
 | 109 | +        delta2 = v[i] - mean;  | 
 | 110 | +        M2 += delta * delta2;  | 
 | 111 | +        width_counter++;  | 
 | 112 | +      }  | 
 | 113 | +    }  | 
 | 114 | + | 
 | 115 | +    // handle last texel if its not 4 aligned  | 
 | 116 | +    if (has_unaligned_width) {  | 
 | 117 | +      in_pos[in_axis_map.x] = fully_packed_4_comp_count;  | 
 | 118 | +      const int remaining_width = width & 0x3;  | 
 | 119 | + | 
 | 120 | +      VEC4_T v = load_texel(t_in, in_pos);  | 
 | 121 | +      for (int i=0; i<remaining_width; i++) {  | 
 | 122 | +        delta = v[i] - mean;  | 
 | 123 | +        mean += delta / width_counter;  | 
 | 124 | +        delta2 = v[i] - mean;  | 
 | 125 | +        M2 += delta * delta2;  | 
 | 126 | +        width_counter++;  | 
 | 127 | +      }  | 
 | 128 | +    }  | 
 | 129 | + | 
 | 130 | +    T var = M2 / (width_counter - 1);  | 
 | 131 | +    T rstd = inversesqrt(var + epsilon);  | 
 | 132 | +    T offset = -rstd * mean;  | 
 | 133 | + | 
 | 134 | +    for (int w = 0; w < packed_width; ++w) {  | 
 | 135 | +      in_pos[in_axis_map.x] = w;  | 
 | 136 | +      VEC4_T v = load_texel(t_in, in_pos);  | 
 | 137 | +      VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0));  | 
 | 138 | +      VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0));  | 
 | 139 | +      VEC4_T outtex = (v * rstd + offset) * weight + bias;  | 
 | 140 | +      write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map);  | 
 | 141 | +    }  | 
 | 142 | + | 
 | 143 | +    write_texel(t_mean, lpos, VEC4_T(mean));  | 
 | 144 | +    write_texel(t_rstd, lpos, VEC4_T(rstd));  | 
80 | 145 |   }  | 
81 |  | - | 
82 |  | -  write_texel(t_mean, lpos, mean);  | 
83 |  | -  write_texel(t_rstd, lpos, rstd);  | 
84 | 146 | }  | 
0 commit comments