|  | 
| 15 | 15 | 
 | 
| 16 | 16 | #define VEC4_T ${texel_type(DTYPE)} | 
| 17 | 17 | 
 | 
|  | 18 | +#define T ${texel_component_type(DTYPE)} | 
|  | 19 | + | 
| 18 | 20 | layout(std430) buffer; | 
| 19 | 21 | 
 | 
| 20 | 22 | ${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} | 
| @@ -48,37 +50,97 @@ void main() { | 
| 48 | 50 | 
 | 
| 49 | 51 |   const int width = int(sizes.x); | 
| 50 | 52 | 
 | 
| 51 |  | -  VEC4_T mean = VEC4_T(0); | 
| 52 |  | -  VEC4_T delta = VEC4_T(0); | 
| 53 |  | -  VEC4_T delta2 = VEC4_T(0); | 
| 54 |  | -  VEC4_T M2 = VEC4_T(0); | 
| 55 |  | - | 
| 56 |  | -  // Use Welford's online algorithm to compute mean and variance in one pass | 
| 57 |  | -  // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm | 
| 58 |  | -  ivec3 in_pos = lpos_to_pos(lpos, in_axis_map); | 
| 59 |  | -  for (int w = 0; w < width; ++w) { | 
| 60 |  | -    in_pos[in_axis_map.x] = w; | 
| 61 |  | -    VEC4_T v = load_texel(t_in, in_pos); | 
| 62 |  | -    delta = v - mean; | 
| 63 |  | -    mean += delta / (w + 1); | 
| 64 |  | -    delta2 = v - mean; | 
| 65 |  | -    M2 += delta * delta2; | 
| 66 |  | -  } | 
| 67 |  | - | 
| 68 |  | -  VEC4_T var = M2 / width; | 
| 69 |  | -  VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5)); | 
| 70 |  | -  VEC4_T offset = -rstd * mean; | 
| 71 |  | - | 
| 72 |  | -  for (int w = 0; w < width; ++w) { | 
| 73 |  | -    in_pos[in_axis_map.x] = w; | 
| 74 |  | -    VEC4_T v = load_texel(t_in, in_pos); | 
| 75 |  | -    // broadcasting | 
| 76 |  | -    VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx; | 
| 77 |  | -    VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx; | 
| 78 |  | -    VEC4_T outtex = (v * rstd + offset) * weight + bias; | 
| 79 |  | -    write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map); | 
|  | 53 | +  if (in_packed_dim != W_DIM) { | 
|  | 54 | +    VEC4_T mean = VEC4_T(0); | 
|  | 55 | +    VEC4_T delta = VEC4_T(0); | 
|  | 56 | +    VEC4_T delta2 = VEC4_T(0); | 
|  | 57 | +    VEC4_T M2 = VEC4_T(0); | 
|  | 58 | + | 
|  | 59 | +    // Use Welford's online algorithm to compute mean and variance in one pass | 
|  | 60 | +    // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm | 
|  | 61 | +    ivec3 in_pos = lpos_to_pos(lpos, in_axis_map); | 
|  | 62 | +    for (int w = 0; w < width; ++w) { | 
|  | 63 | +      in_pos[in_axis_map.x] = w; | 
|  | 64 | +      VEC4_T v = load_texel(t_in, in_pos); | 
|  | 65 | +      delta = v - mean; | 
|  | 66 | +      mean += delta / (w + 1); | 
|  | 67 | +      delta2 = v - mean; | 
|  | 68 | +      M2 += delta * delta2; | 
|  | 69 | +    } | 
|  | 70 | + | 
|  | 71 | +    VEC4_T var = M2 / width; | 
|  | 72 | +    VEC4_T rstd = pow(var + epsilon, VEC4_T(-0.5)); | 
|  | 73 | +    VEC4_T offset = -rstd * mean; | 
|  | 74 | + | 
|  | 75 | +    for (int w = 0; w < width; ++w) { | 
|  | 76 | +      in_pos[in_axis_map.x] = w; | 
|  | 77 | +      VEC4_T v = load_texel(t_in, in_pos); | 
|  | 78 | +      // broadcasting | 
|  | 79 | +      VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)).xxxx; | 
|  | 80 | +      VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)).xxxx; | 
|  | 81 | +      VEC4_T outtex = (v * rstd + offset) * weight + bias; | 
|  | 82 | +      write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map); | 
|  | 83 | +    } | 
|  | 84 | + | 
|  | 85 | +    write_texel(t_mean, lpos, mean); | 
|  | 86 | +    write_texel(t_rstd, lpos, rstd); | 
|  | 87 | +  } else { | 
|  | 88 | +    const int packed_width = divup4(width); | 
|  | 89 | + | 
|  | 90 | +    T mean = T(0); | 
|  | 91 | +    T delta = T(0); | 
|  | 92 | +    T delta2 = T(0); | 
|  | 93 | +    T M2 = T(0); | 
|  | 94 | +    // Use Welford's online algorithm to compute mean and variance in one pass | 
|  | 95 | +    // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm | 
|  | 96 | +    ivec3 in_pos = lpos_to_pos(lpos, in_axis_map); | 
|  | 97 | +    T width_counter = T(1); | 
|  | 98 | + | 
|  | 99 | +    const bool has_unaligned_width = (width & 0x3) != 0; | 
|  | 100 | +    const int fully_packed_4_comp_count = packed_width - mix(0, 1, has_unaligned_width); | 
|  | 101 | + | 
|  | 102 | +    // iterate through texels that are fully packed, i.e. have all 4 components | 
|  | 103 | +    for (int w = 0; w < fully_packed_4_comp_count; ++w) { | 
|  | 104 | +      in_pos[in_axis_map.x] = w; | 
|  | 105 | +      VEC4_T v = load_texel(t_in, in_pos); | 
|  | 106 | +      for (int i=0; i<4; i++) { | 
|  | 107 | +        delta = v[i] - mean; | 
|  | 108 | +        mean += delta / width_counter; | 
|  | 109 | +        delta2 = v[i] - mean; | 
|  | 110 | +        M2 += delta * delta2; | 
|  | 111 | +        width_counter++; | 
|  | 112 | +      } | 
|  | 113 | +    } | 
|  | 114 | + | 
|  | 115 | +    // handle the last texel if the width is not 4-aligned | 
|  | 116 | +    if (has_unaligned_width) { | 
|  | 117 | +      in_pos[in_axis_map.x] = fully_packed_4_comp_count; | 
|  | 118 | +      const int remaining_width = width & 0x3; | 
|  | 119 | + | 
|  | 120 | +      VEC4_T v = load_texel(t_in, in_pos); | 
|  | 121 | +      for (int i=0; i<remaining_width; i++) { | 
|  | 122 | +        delta = v[i] - mean; | 
|  | 123 | +        mean += delta / width_counter; | 
|  | 124 | +        delta2 = v[i] - mean; | 
|  | 125 | +        M2 += delta * delta2; | 
|  | 126 | +        width_counter++; | 
|  | 127 | +      } | 
|  | 128 | +    } | 
|  | 129 | + | 
|  | 130 | +    T var = M2 / (width_counter - 1); | 
|  | 131 | +    T rstd = inversesqrt(var + epsilon); | 
|  | 132 | +    T offset = -rstd * mean; | 
|  | 133 | + | 
|  | 134 | +    for (int w = 0; w < packed_width; ++w) { | 
|  | 135 | +      in_pos[in_axis_map.x] = w; | 
|  | 136 | +      VEC4_T v = load_texel(t_in, in_pos); | 
|  | 137 | +      VEC4_T weight = load_texel(t_weight, ivec3(w, 0, 0)); | 
|  | 138 | +      VEC4_T bias = load_texel(t_bias, ivec3(w, 0, 0)); | 
|  | 139 | +      VEC4_T outtex = (v * rstd + offset) * weight + bias; | 
|  | 140 | +      write_texel_lpos(t_out, ivec3(w, lpos.y, lpos.z), outtex, out_axis_map); | 
|  | 141 | +    } | 
|  | 142 | + | 
|  | 143 | +    write_texel(t_mean, lpos, VEC4_T(mean)); | 
|  | 144 | +    write_texel(t_rstd, lpos, VEC4_T(rstd)); | 
| 80 | 145 |   } | 
| 81 |  | - | 
| 82 |  | -  write_texel(t_mean, lpos, mean); | 
| 83 |  | -  write_texel(t_rstd, lpos, rstd); | 
| 84 | 146 | } | 
0 commit comments