@@ -549,8 +549,48 @@ struct Denoiser {
549549};
550550
551551struct CompVisDenoiser : public Denoiser {
552- float sigmas[TIMESTEPS];
553- float log_sigmas[TIMESTEPS];
552+
553+ private:
554+ struct Constants {
555+ const float beta_start = 0 .00085f ;
556+ const float beta_end = 0 .0120f ;
557+ float alphas_cumprods[TIMESTEPS];
558+ float sigmas[TIMESTEPS];
559+ float log_sigmas[TIMESTEPS];
560+ Constants () {
561+ double ls_sqrt = std::sqrt (static_cast <double >(beta_start));
562+ double le_sqrt = std::sqrt (static_cast <double >(beta_end));
563+ double step = (le_sqrt - ls_sqrt) / (TIMESTEPS - 1 );
564+ double alphas_cumprod = 1.0 ;
565+
566+ for (int i = 0 ; i < TIMESTEPS; ++i) {
567+ double sqrt_beta = ls_sqrt + step * i;
568+ alphas_cumprod *= (1.0 - (sqrt_beta * sqrt_beta));
569+ double sigma = std::sqrt ((1.0 - alphas_cumprod) / alphas_cumprod);
570+ alphas_cumprods[i] = static_cast <float >(alphas_cumprod);
571+ sigmas[i] = static_cast <float >(sigma);
572+ log_sigmas[i] = static_cast <float >(std::log (sigma));
573+ }
574+ }
575+ static const Constants& get_instance () {
576+ static Constants instance;
577+ return instance;
578+ }
579+ };
580+
581+ const float *sigmas = get_sigmas();
582+ const float *log_sigmas = get_log_sigmas();
583+
584+ public:
585+ static const float * get_sigmas () {
586+ return Constants::get_instance ().sigmas ;
587+ }
588+ static const float * get_log_sigmas () {
589+ return Constants::get_instance ().log_sigmas ;
590+ }
591+ static const float * get_alphas_cumprods () {
592+ return Constants::get_instance ().alphas_cumprods ;
593+ }
554594
555595 float sigma_data = 1 .0f ;
556596
@@ -564,20 +604,12 @@ struct CompVisDenoiser : public Denoiser {
564604
565605 float sigma_to_t (float sigma) override {
566606 float log_sigma = std::log (sigma);
567- std::vector<float > dists;
568- dists.reserve (TIMESTEPS);
569- for (float log_sigma_val : log_sigmas) {
570- dists.push_back (log_sigma - log_sigma_val);
571- }
607+ const float * high_ptr = std::upper_bound (log_sigmas, log_sigmas + TIMESTEPS, log_sigma);
572608
573- int low_idx = 0 ;
574- for (size_t i = 0 ; i < TIMESTEPS; i++) {
575- if (dists[i] >= 0 ) {
576- low_idx++;
577- }
578- }
579- low_idx = std::min (std::max (low_idx - 1 , 0 ), TIMESTEPS - 2 );
580- int high_idx = low_idx + 1 ;
609+ int high_idx = static_cast <int >(high_ptr - log_sigmas);
610+ int low_idx = high_idx - 1 ;
611+ low_idx = std::clamp (low_idx, 0 , TIMESTEPS - 2 );
612+ high_idx = low_idx + 1 ;
581613
582614 float low = log_sigmas[low_idx];
583615 float high = log_sigmas[high_idx];
@@ -1566,27 +1598,15 @@ static sd::Tensor<float> sample_tcd(denoise_cb_t model,
15661598 const std::vector<float >& sigmas,
15671599 std::shared_ptr<RNG> rng,
15681600 float eta) {
1569- float beta_start = 0 .00085f ;
1570- float beta_end = 0 .0120f ;
1571- std::vector<double > alphas_cumprod (TIMESTEPS);
1572- std::vector<double > compvis_sigmas (TIMESTEPS);
1573- for (int i = 0 ; i < TIMESTEPS; i++) {
1574- alphas_cumprod[i] =
1575- (i == 0 ? 1 .0f : alphas_cumprod[i - 1 ]) *
1576- (1 .0f -
1577- std::pow (sqrtf (beta_start) +
1578- (sqrtf (beta_end) - sqrtf (beta_start)) *
1579- ((float )i / (TIMESTEPS - 1 )),
1580- 2 ));
1581- compvis_sigmas[i] =
1582- std::sqrt ((1 - alphas_cumprod[i]) / alphas_cumprod[i]);
1583- }
1601+
1602+ const float * alphas_cumprod = CompVisDenoiser::get_alphas_cumprods ();
1603+ const float * compvis_sigmas = CompVisDenoiser::get_sigmas ();
15841604
15851605 auto get_timestep_from_sigma = [&](float s) -> int {
1586- auto it = std::lower_bound (compvis_sigmas. begin () , compvis_sigmas. end () , s);
1587- if (it == compvis_sigmas. begin () ) return 0 ;
1588- if (it == compvis_sigmas. end () ) return TIMESTEPS - 1 ;
1589- int idx_high = static_cast <int >(std::distance (compvis_sigmas. begin () , it));
1606+ auto it = std::lower_bound (compvis_sigmas, compvis_sigmas + TIMESTEPS , s);
1607+ if (it == compvis_sigmas) return 0 ;
1608+ if (it == compvis_sigmas + TIMESTEPS ) return TIMESTEPS - 1 ;
1609+ int idx_high = static_cast <int >(std::distance (compvis_sigmas, it));
15901610 int idx_low = idx_high - 1 ;
15911611 if (std::abs (compvis_sigmas[idx_high] - s) < std::abs (compvis_sigmas[idx_low] - s)) {
15921612 return idx_high;
@@ -1612,7 +1632,7 @@ static sd::Tensor<float> sample_tcd(denoise_cb_t model,
16121632 float alpha_prod_t = 1 .0f / (sigma * sigma + 1 .0f );
16131633 float beta_prod_t = 1 .0f - alpha_prod_t ;
16141634 float alpha_prod_t_prev = 1 .0f / (sigma_to * sigma_to + 1 .0f );
1615- float alpha_prod_s = static_cast < float >( alphas_cumprod[timestep_s]) ;
1635+ float alpha_prod_s = alphas_cumprod[timestep_s];
16161636 float beta_prod_s = 1 .0f - alpha_prod_s;
16171637
16181638 sd::Tensor<float > pred_original_sample = ((x / std::sqrt (sigma * sigma + 1 )) -
0 commit comments