Skip to content

Commit 1798ec0

Browse files
committed
Fix NaN issue that occurs when using CUDA with k-quant weights
1 parent 2fec01d commit 1798ec0

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

z_image.hpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,15 @@ namespace ZImage {
8585
}
8686
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) / multiple_of);
8787
blocks["w1"] = std::make_shared<Linear>(dim, hidden_dim, false);
88-
blocks["w2"] = std::make_shared<Linear>(hidden_dim, dim, false);
88+
89+
bool force_prec_f32 = false;
90+
float scale = 1.f / 128.f;
91+
#ifdef SD_USE_VULKAN
92+
force_prec_f32 = true;
93+
#endif
94+
// The purpose of the scale here is to prevent NaN issues in certain situations.
95+
// For example, when running on CUDA with k-quant quantized weights.
96+
blocks["w2"] = std::make_shared<Linear>(hidden_dim, dim, false, false, force_prec_f32, 1.f / 128.f);
8997
blocks["w3"] = std::make_shared<Linear>(dim, hidden_dim, false);
9098
}
9199

0 commit comments

Comments
 (0)