
Commit 6f463fa

fix: prevent NaN issues with Z-Image on certain ROCm setups
1 parent: 710169d


1 file changed: +8 −3 lines


z_image.hpp

Lines changed: 8 additions & 3 deletions
@@ -30,7 +30,12 @@ namespace ZImage {
     JointAttention(int64_t hidden_size, int64_t head_dim, int64_t num_heads, int64_t num_kv_heads, bool qk_norm)
         : head_dim(head_dim), num_heads(num_heads), num_kv_heads(num_kv_heads), qk_norm(qk_norm) {
         blocks["qkv"] = std::make_shared<Linear>(hidden_size, (num_heads + num_kv_heads * 2) * head_dim, false);
-        blocks["out"] = std::make_shared<Linear>(num_heads * head_dim, hidden_size, false);
+        float scale = 1.f;
+#if GGML_USE_HIP
+        // Prevent NaN issues with certain ROCm setups
+        scale = 1.f / 16.f;
+#endif
+        blocks["out"] = std::make_shared<Linear>(num_heads * head_dim, hidden_size, false, false, false, scale);
         if (qk_norm) {
             blocks["q_norm"] = std::make_shared<RMSNorm>(head_dim);
             blocks["k_norm"] = std::make_shared<RMSNorm>(head_dim);
@@ -93,7 +98,7 @@ namespace ZImage {
 #endif
         // The purpose of the scale here is to prevent NaN issues in certain situations.
         // For example, when using CUDA but the weights are k-quants.
-        blocks["w2"] = std::make_shared<Linear>(hidden_dim, dim, false, false, force_prec_f32, 1.f / 128.f);
+        blocks["w2"] = std::make_shared<Linear>(hidden_dim, dim, false, false, force_prec_f32, scale);
         blocks["w3"] = std::make_shared<Linear>(dim, hidden_dim, false);
     }
 
@@ -667,4 +672,4 @@ namespace ZImage {
 
 } // namespace ZImage
 
-#endif // __Z_IMAGE_HPP__
+#endif // __Z_IMAGE_HPP__
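
For background, here is a sketch of the general technique such a scale argument enables. The exact semantics of the scale parameter of this repository's Linear class are not visible in this diff, and scaled_linear below is a hypothetical helper, not the actual implementation: the idea is to multiply the input by the scale before the matmul and divide the result by it afterwards. Mathematically this is a no-op, but when a backend accumulates the matmul in fp16, scaling by 1/16 keeps the partial sums well below the fp16 maximum of ~65504, so they never overflow to inf and propagate NaN through the rest of the network.

#include <vector>

// Hypothetical helper sketching the pre-scale / post-unscale trick.
// Computes y = W * x, but evaluates it as (W * (x * scale)) / scale so
// that the intermediate products and partial sums stay small enough to
// fit in fp16 on backends that accumulate the matmul in half precision.
std::vector<float> scaled_linear(const std::vector<float>& W,  // [out_dim * in_dim]
                                 const std::vector<float>& x,  // [in_dim]
                                 int out_dim, int in_dim, float scale) {
    std::vector<float> y(out_dim, 0.f);
    for (int o = 0; o < out_dim; ++o) {
        float acc = 0.f;  // on an fp16 path this accumulator is the overflow risk
        for (int i = 0; i < in_dim; ++i) {
            acc += W[o * in_dim + i] * (x[i] * scale);
        }
        y[o] = acc / scale;  // undo the scale to recover the unscaled output
    }
    return y;
}

With scale = 1.f the helper reduces to a plain matmul, which mirrors how the patch leaves non-HIP builds unaffected: the workaround costs nothing where it is not needed.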

0 commit comments