| 
139 | 139 |     "int in_zero_point, bool channel_last=False) -> (Tensor out)"  | 
140 | 140 | )  | 
141 | 141 | lib.define("linalg_vector_norm(Tensor X) -> (Tensor Y)")  | 
 | 142 | +lib.define("rms_norm(Tensor X, float eps, Tensor W) -> (Tensor Y)")  | 
142 | 143 | lib.define(  | 
143 | 144 |     "transposed_im2row(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "  | 
144 | 145 |     "int[2] output_padding, Tensor in_zero_point, bool channel_last=False) -> (Tensor out)"  | 
 | 
210 | 211 |     "fully_connected.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)"  | 
211 | 212 | )  | 
212 | 213 | lib.define("linalg_vector_norm.out(Tensor X, *, Tensor(a!) out) -> Tensor(a!)")  | 
 | 214 | +lib.define(  | 
 | 215 | +    "rms_norm.out(Tensor X, float eps, Tensor W, *, Tensor(a!) out) -> Tensor(a!)"  | 
 | 216 | +)  | 
213 | 217 | lib.define(  | 
214 | 218 |     "quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "  | 
215 | 219 |     "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"  | 
@@ -615,6 +619,15 @@ def linalg_vector_norm_meta(  | 
615 | 619 |     return X.new_empty([], dtype=X.dtype)  | 
616 | 620 | 
 
  | 
617 | 621 | 
 
  | 
 | 622 | +@register_fake("cadence::rms_norm")  | 
 | 623 | +def rms_norm_meta(  | 
 | 624 | +    X: torch.Tensor,  | 
 | 625 | +    eps: float,  | 
 | 626 | +    weight: torch.Tensor,  | 
 | 627 | +) -> torch.Tensor:  | 
 | 628 | +    return X.new_empty(X.shape, dtype=X.dtype)  | 
 | 629 | + | 
 | 630 | + | 
618 | 631 | @register_fake("cadence::requantize")  | 
619 | 632 | def requantize_meta(  | 
620 | 633 |     input: torch.Tensor,  | 
 | 
0 commit comments