diff --git a/examples/qualcomm/oss_scripts/llama/model/static_llama.py b/examples/qualcomm/oss_scripts/llama/model/static_llama.py
index 192f23de302..49b38445c6a 100755
--- a/examples/qualcomm/oss_scripts/llama/model/static_llama.py
+++ b/examples/qualcomm/oss_scripts/llama/model/static_llama.py
@@ -70,7 +70,7 @@ def __init__(self, config: ModelArgs, output_new_cache_only=False):
         self.scale = float(self.head_dim) ** 0.5

-        if config.enable_r3:
+        if hasattr(config, "enable_r3") and config.enable_r3:
             self.register_buffer(
                 "r3_weight",
                 torch.tensor(
@@ -186,11 +186,11 @@ def forward_sha(
         ]
         for i in range(len(q)):
             q[i] = apply_rotary_emb_single(q[i], freqs_cos, freqs_sin)
-            if self.config.enable_r3:
+            if hasattr(self.config, "enable_r3") and self.config.enable_r3:
                 q[i] = torch.matmul(q[i], self.r3_weight.T)
         for i in range(len(k)):
            k[i] = apply_rotary_emb_single(k[i], freqs_cos, freqs_sin)
-            if self.config.enable_r3:
+            if hasattr(self.config, "enable_r3") and self.config.enable_r3:
                 k[i] = torch.matmul(k[i], self.r3_weight.T)
             k[i] = k[i].transpose(1, 2)
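
Note on the guard pattern above: the diff switches from reading config.enable_r3 directly to an hasattr check so that configs which predate the R3 option do not raise AttributeError. The snippet below is a minimal, self-contained sketch of that pattern, using a hypothetical ModelArgsStub stand-in (not the real ModelArgs) purely for illustration; it also shows the equivalent getattr-with-default idiom from standard Python.

from dataclasses import dataclass


@dataclass
class ModelArgsStub:  # hypothetical stand-in for ModelArgs, for illustration only
    dim: int = 2048  # no enable_r3 field, like an older config


config = ModelArgsStub()

# hasattr + truthiness check, as used in the diff:
use_r3 = hasattr(config, "enable_r3") and config.enable_r3

# Equivalent, slightly more compact idiom with the same semantics:
use_r3 = getattr(config, "enable_r3", False)

print(use_r3)  # False: the stub config does not define enable_r3

Either form leaves behavior unchanged for configs that do set enable_r3, while defaulting to "R3 disabled" when the attribute is absent.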