Skip to content

Commit d5b05f7

Browse files
authored
feat: support independent sampler rng (leejet#978)
1 parent 6d6dc1b commit d5b05f7

File tree

4 files changed

+68
-26
lines changed

4 files changed

+68
-26
lines changed

examples/cli/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ Options:
9595
--type weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K). If not specified, the default is the
9696
type of the weight file
9797
--rng RNG, one of [std_default, cuda, cpu], default: cuda(sd-webui), cpu(comfyui)
98+
--sampler-rng sampler RNG, one of [std_default, cuda, cpu]. If not specified, use --rng
9899
-s, --seed RNG seed (default: 42, use random seed for < 0)
99100
--sampling-method sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing,
100101
tcd] (default: euler for Flux/SD3/Wan, euler_a otherwise)

examples/cli/main.cpp

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -110,21 +110,22 @@ struct SDParams {
110110
int fps = 16;
111111
float vace_strength = 1.f;
112112

113-
float strength = 0.75f;
114-
float control_strength = 0.9f;
115-
rng_type_t rng_type = CUDA_RNG;
116-
int64_t seed = 42;
117-
bool verbose = false;
118-
bool offload_params_to_cpu = false;
119-
bool control_net_cpu = false;
120-
bool clip_on_cpu = false;
121-
bool vae_on_cpu = false;
122-
bool diffusion_flash_attn = false;
123-
bool diffusion_conv_direct = false;
124-
bool vae_conv_direct = false;
125-
bool canny_preprocess = false;
126-
bool color = false;
127-
int upscale_repeats = 1;
113+
float strength = 0.75f;
114+
float control_strength = 0.9f;
115+
rng_type_t rng_type = CUDA_RNG;
116+
rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
117+
int64_t seed = 42;
118+
bool verbose = false;
119+
bool offload_params_to_cpu = false;
120+
bool control_net_cpu = false;
121+
bool clip_on_cpu = false;
122+
bool vae_on_cpu = false;
123+
bool diffusion_flash_attn = false;
124+
bool diffusion_conv_direct = false;
125+
bool vae_conv_direct = false;
126+
bool canny_preprocess = false;
127+
bool color = false;
128+
int upscale_repeats = 1;
128129

129130
// Photo Maker
130131
std::string photo_maker_path;
@@ -214,6 +215,7 @@ void print_params(SDParams params) {
214215
printf(" flow_shift: %.2f\n", params.flow_shift);
215216
printf(" strength(img2img): %.2f\n", params.strength);
216217
printf(" rng: %s\n", sd_rng_type_name(params.rng_type));
218+
printf(" sampler rng: %s\n", sd_rng_type_name(params.sampler_rng_type));
217219
printf(" seed: %zd\n", params.seed);
218220
printf(" batch_count: %d\n", params.batch_count);
219221
printf(" vae_tiling: %s\n", params.vae_tiling_params.enabled ? "true" : "false");
@@ -886,6 +888,20 @@ void parse_args(int argc, const char** argv, SDParams& params) {
886888
return 1;
887889
};
888890

891+
// Option handler for `--sampler-rng`.
//
// Reads the value that follows the option (argv[index + 1]), converts it with
// str_to_rng_type() and stores the result in params.sampler_rng_type (params
// is captured by reference).
//
// Returns -1 when no value follows the option or when the value is not a
// recognised RNG name, and 1 on success.
// NOTE(review): 1 presumably means "one extra argv entry was consumed" and -1
// "parse error" -- confirm against the code that dispatches these handlers.
auto on_sampler_rng_arg = [&](int argc, const char** argv, int index) {
892+
if (++index >= argc) {
893+
return -1;
894+
}
895+
const char* arg = argv[index];
896+
params.sampler_rng_type = str_to_rng_type(arg);
897+
// RNG_TYPE_COUNT is what str_to_rng_type() is checked against for a failed
// parse; note that params.sampler_rng_type has already been overwritten
// with that value by the time the error is reported.
if (params.sampler_rng_type == RNG_TYPE_COUNT) {
898+
fprintf(stderr, "error: invalid sampler rng type %s\n",
899+
arg);
900+
return -1;
901+
}
902+
return 1;
903+
};
904+
889905
auto on_schedule_arg = [&](int argc, const char** argv, int index) {
890906
if (++index >= argc) {
891907
return -1;
@@ -1126,6 +1142,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
11261142
"--rng",
11271143
"RNG, one of [std_default, cuda, cpu], default: cuda(sd-webui), cpu(comfyui)",
11281144
on_rng_arg},
1145+
{"",
1146+
"--sampler-rng",
1147+
"sampler RNG, one of [std_default, cuda, cpu]. If not specified, use --rng",
1148+
on_sampler_rng_arg},
11291149
{"-s",
11301150
"--seed",
11311151
"RNG seed (default: 42, use random seed for < 0)",
@@ -1319,6 +1339,9 @@ std::string get_image_params(SDParams params, int64_t seed) {
13191339
parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", ";
13201340
parameter_string += "Model: " + sd_basename(params.model_path) + ", ";
13211341
parameter_string += "RNG: " + std::string(sd_rng_type_name(params.rng_type)) + ", ";
1342+
if (params.sampler_rng_type != RNG_TYPE_COUNT) {
1343+
parameter_string += "Sampler RNG: " + std::string(sd_rng_type_name(params.sampler_rng_type)) + ", ";
1344+
}
13221345
parameter_string += "Sampler: " + std::string(sd_sample_method_name(params.sample_params.sample_method));
13231346
if (params.sample_params.scheduler != DEFAULT) {
13241347
parameter_string += " " + std::string(sd_schedule_name(params.sample_params.scheduler));
@@ -1758,6 +1781,7 @@ int main(int argc, const char* argv[]) {
17581781
params.n_threads,
17591782
params.wtype,
17601783
params.rng_type,
1784+
params.sampler_rng_type,
17611785
params.prediction,
17621786
params.lora_apply_mode,
17631787
params.offload_params_to_cpu,

stable-diffusion.cpp

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,11 @@ class StableDiffusionGGML {
9999
bool vae_decode_only = false;
100100
bool free_params_immediately = false;
101101

102-
std::shared_ptr<RNG> rng = std::make_shared<STDDefaultRNG>();
103-
int n_threads = -1;
104-
float scale_factor = 0.18215f;
105-
float shift_factor = 0.f;
102+
std::shared_ptr<RNG> rng = std::make_shared<PhiloxRNG>();
103+
std::shared_ptr<RNG> sampler_rng = nullptr;
104+
int n_threads = -1;
105+
float scale_factor = 0.18215f;
106+
float shift_factor = 0.f;
106107

107108
std::shared_ptr<Conditioner> cond_stage_model;
108109
std::shared_ptr<FrozenCLIPVisionEmbedder> clip_vision; // for svd or wan2.1 i2v
@@ -188,6 +189,16 @@ class StableDiffusionGGML {
188189
}
189190
}
190191

192+
// Creates a new RNG instance for the requested type.
//
//   STD_DEFAULT_RNG -> STDDefaultRNG
//   CPU_RNG         -> MT19937RNG
//   anything else   -> PhiloxRNG
//
// The final branch is a catch-all: besides CUDA_RNG it also handles any other
// value passed in (for example RNG_TYPE_COUNT), so this function never returns
// a null pointer. Every call allocates a fresh object; nothing is cached or
// shared between calls.
std::shared_ptr<RNG> get_rng(rng_type_t rng_type) {
193+
if (rng_type == STD_DEFAULT_RNG) {
194+
return std::make_shared<STDDefaultRNG>();
195+
} else if (rng_type == CPU_RNG) {
196+
return std::make_shared<MT19937RNG>();
197+
} else { // default: CUDA_RNG
198+
return std::make_shared<PhiloxRNG>();
199+
}
200+
}
201+
191202
bool init(const sd_ctx_params_t* sd_ctx_params) {
192203
n_threads = sd_ctx_params->n_threads;
193204
vae_decode_only = sd_ctx_params->vae_decode_only;
@@ -197,12 +208,11 @@ class StableDiffusionGGML {
197208
use_tiny_autoencoder = taesd_path.size() > 0;
198209
offload_params_to_cpu = sd_ctx_params->offload_params_to_cpu;
199210

200-
if (sd_ctx_params->rng_type == STD_DEFAULT_RNG) {
201-
rng = std::make_shared<STDDefaultRNG>();
202-
} else if (sd_ctx_params->rng_type == CUDA_RNG) {
203-
rng = std::make_shared<PhiloxRNG>();
204-
} else if (sd_ctx_params->rng_type == CPU_RNG) {
205-
rng = std::make_shared<MT19937RNG>();
211+
rng = get_rng(sd_ctx_params->rng_type);
212+
if (sd_ctx_params->sampler_rng_type != RNG_TYPE_COUNT) {
213+
sampler_rng = get_rng(sd_ctx_params->sampler_rng_type);
214+
} else {
215+
sampler_rng = rng;
206216
}
207217

208218
ggml_log_set(ggml_log_callback_default, nullptr);
@@ -1736,7 +1746,7 @@ class StableDiffusionGGML {
17361746
return denoised;
17371747
};
17381748

1739-
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta);
1749+
sample_k_diffusion(method, denoise, work_ctx, x, sigmas, sampler_rng, eta);
17401750

17411751
if (inverse_noise_scaling) {
17421752
x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x);
@@ -2291,6 +2301,7 @@ void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) {
22912301
sd_ctx_params->n_threads = get_num_physical_cores();
22922302
sd_ctx_params->wtype = SD_TYPE_COUNT;
22932303
sd_ctx_params->rng_type = CUDA_RNG;
2304+
sd_ctx_params->sampler_rng_type = RNG_TYPE_COUNT;
22942305
sd_ctx_params->prediction = DEFAULT_PRED;
22952306
sd_ctx_params->lora_apply_mode = LORA_APPLY_AUTO;
22962307
sd_ctx_params->offload_params_to_cpu = false;
@@ -2332,6 +2343,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
23322343
"n_threads: %d\n"
23332344
"wtype: %s\n"
23342345
"rng_type: %s\n"
2346+
"sampler_rng_type: %s\n"
23352347
"prediction: %s\n"
23362348
"offload_params_to_cpu: %s\n"
23372349
"keep_clip_on_cpu: %s\n"
@@ -2362,6 +2374,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
23622374
sd_ctx_params->n_threads,
23632375
sd_type_name(sd_ctx_params->wtype),
23642376
sd_rng_type_name(sd_ctx_params->rng_type),
2377+
sd_rng_type_name(sd_ctx_params->sampler_rng_type),
23652378
sd_prediction_name(sd_ctx_params->prediction),
23662379
BOOL_STR(sd_ctx_params->offload_params_to_cpu),
23672380
BOOL_STR(sd_ctx_params->keep_clip_on_cpu),
@@ -2823,6 +2836,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
28232836
LOG_INFO("generating image: %i/%i - seed %" PRId64, b + 1, batch_count, cur_seed);
28242837

28252838
sd_ctx->sd->rng->manual_seed(cur_seed);
2839+
sd_ctx->sd->sampler_rng->manual_seed(cur_seed);
28262840
struct ggml_tensor* x_t = init_latent;
28272841
struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1);
28282842
ggml_ext_im_set_randn_f32(noise, sd_ctx->sd->rng);
@@ -2949,6 +2963,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
29492963
seed = rand();
29502964
}
29512965
sd_ctx->sd->rng->manual_seed(seed);
2966+
sd_ctx->sd->sampler_rng->manual_seed(seed);
29522967

29532968
int sample_steps = sd_img_gen_params->sample_params.sample_steps;
29542969

@@ -3240,6 +3255,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
32403255
}
32413256

32423257
sd_ctx->sd->rng->manual_seed(seed);
3258+
sd_ctx->sd->sampler_rng->manual_seed(seed);
32433259

32443260
int64_t t0 = ggml_time_ms();
32453261

stable-diffusion.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ typedef struct {
173173
int n_threads;
174174
enum sd_type_t wtype;
175175
enum rng_type_t rng_type;
176+
enum rng_type_t sampler_rng_type;
176177
enum prediction_t prediction;
177178
enum lora_apply_mode_t lora_apply_mode;
178179
bool offload_params_to_cpu;

0 commit comments

Comments
 (0)