@@ -30,7 +30,12 @@ namespace ZImage {
     JointAttention(int64_t hidden_size, int64_t head_dim, int64_t num_heads, int64_t num_kv_heads, bool qk_norm)
         : head_dim(head_dim), num_heads(num_heads), num_kv_heads(num_kv_heads), qk_norm(qk_norm) {
         blocks["qkv"] = std::make_shared<Linear>(hidden_size, (num_heads + num_kv_heads * 2) * head_dim, false);
-        blocks["out"] = std::make_shared<Linear>(num_heads * head_dim, hidden_size, false);
+        float scale = 1.f;
+#if GGML_USE_HIP
+        // Prevent NaN issues with certain ROCm setups
+        scale = 1.f / 16.f;
+#endif
+        blocks["out"] = std::make_shared<Linear>(num_heads * head_dim, hidden_size, false, false, false, scale);
         if (qk_norm) {
             blocks["q_norm"] = std::make_shared<RMSNorm>(head_dim);
             blocks["k_norm"] = std::make_shared<RMSNorm>(head_dim);
@@ -93,7 +98,7 @@ namespace ZImage {
 #endif
         // The purpose of the scale here is to prevent NaN issues in certain situations.
         // For example, when using CUDA but the weights are k-quants.
-        blocks["w2"] = std::make_shared<Linear>(hidden_dim, dim, false, false, force_prec_f32, 1.f / 128.f);
+        blocks["w2"] = std::make_shared<Linear>(hidden_dim, dim, false, false, force_prec_f32, scale);
         blocks["w3"] = std::make_shared<Linear>(dim, hidden_dim, false);
 
 
@@ -667,4 +672,4 @@ namespace ZImage {
 
 } // namespace ZImage
 
-#endif // __Z_IMAGE_HPP__
+#endif // __Z_IMAGE_HPP__
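
Both hunks pass a backend-dependent scale into the Linear block: 1.f / 16.f for the attention output projection when building for HIP/ROCm, and a shared scale variable (replacing the hardcoded 1.f / 128.f) for the feed-forward w2 projection, where CUDA with k-quant weights was hitting NaNs. The sketch below is a minimal, hedged illustration of the general idea, not this repository's actual Linear implementation: it assumes the scale is applied to the operand before the matrix multiplication and compensated afterwards, which keeps half-precision intermediates in range while leaving the layer's output mathematically unchanged.

    // Minimal sketch (assumption, not sd.cpp's Linear): a compensated scale
    // around a linear projection. Scaling the operand down before the matmul
    // keeps half-precision intermediates small; dividing the result by the
    // same scale afterwards restores the original magnitude.
    #include <cstdio>
    #include <vector>

    // Plain float reference: y = x * W  (x: 1 x in, W: in x out, y: 1 x out)
    static std::vector<float> linear(const std::vector<float>& x,
                                     const std::vector<float>& W,
                                     int in_features, int out_features,
                                     float scale /* e.g. 1.f or 1.f / 16.f */) {
        std::vector<float> y(out_features, 0.0f);
        for (int o = 0; o < out_features; ++o) {
            float acc = 0.0f;
            for (int i = 0; i < in_features; ++i) {
                // On the real backends this accumulation may run in fp16;
                // pre-scaling keeps |acc| well below the fp16 maximum (~65504)
                // and so avoids inf/NaN.
                acc += (x[i] * scale) * W[i * out_features + o];
            }
            y[o] = acc / scale;  // compensate, so the result matches scale == 1
        }
        return y;
    }

    int main() {
        const int in_features = 4, out_features = 2;
        std::vector<float> x = {1000.f, -2000.f, 1500.f, 500.f};
        std::vector<float> W(in_features * out_features, 0.25f);

        std::vector<float> y_plain  = linear(x, W, in_features, out_features, 1.f);
        std::vector<float> y_scaled = linear(x, W, in_features, out_features, 1.f / 16.f);

        // Both paths print the same values (up to float rounding).
        for (int o = 0; o < out_features; ++o) {
            std::printf("out[%d]: plain=%g scaled=%g\n", o, y_plain[o], y_scaled[o]);
        }
        return 0;
    }

Compensating around the matmul rather than rescaling the stored weights leaves the checkpoint untouched and makes the workaround easy to gate per backend, which is presumably why the scale is exposed as a constructor argument on the Linear block here.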