diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/evo_search.yml b/ppdiffusers/examples/class_conditional_image_generation/DiT/evo_search.yml new file mode 100644 index 000000000..77dcb5f7e --- /dev/null +++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/evo_search.yml @@ -0,0 +1,22 @@ +baseline: + timestep_type: ["linspace", "leading", "trailing"] + orders: [3, 2, 1] +parents: + rank_prob: 1.0 + rank_bar: 2 + absolute_bar: 1 +crossover: + prob: 0.15 + better_prob: 0.6 +mutate: + order: + prob: 0.2 + dist: + 1: 0.3 + 2: 0.5 + 3: 0.2 + timestep: + prob: 0.2 + scale: 3 +metric: + indicator: -1 \ No newline at end of file diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/gen_fixed_noise.py b/ppdiffusers/examples/class_conditional_image_generation/DiT/gen_fixed_noise.py new file mode 100644 index 000000000..f2fafe163 --- /dev/null +++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/gen_fixed_noise.py @@ -0,0 +1,8 @@ +import paddle +import numpy as np + +paddle.seed(1234) # 固定随机种子 +noise = paddle.randn([5000, 4, 32, 32], dtype="float32") # 举例:形状按你模型需要调整 +# np.save("/path/to/dit_fixed_noise_B5000.npy", noise.numpy()) # 保存为 .npy +# # 或保存为 .pdparams +paddle.save({"fixed_noise": noise}, "/share/chenqian-local/PaddleMIX/ppdiffusers/examples/class_conditional_image_generation/dit_fixed_noise_B5000.pdparams") \ No newline at end of file diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/infer_demo_dit copy.py b/ppdiffusers/examples/class_conditional_image_generation/DiT/infer_demo_dit copy.py new file mode 100644 index 000000000..266f930ff --- /dev/null +++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/infer_demo_dit copy.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddlenlp.trainer import set_seed + +from ppdiffusers import DDIMScheduler, DiTPipeline, DPMSolverMultistepScheduler + +dtype = paddle.float32 +pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", paddle_dtype=dtype) +# import ipdb; ipdb.set_trace() +# use DDIMScheduler for inference +# pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) +pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) +pipe.scheduler.config.algorithm_type = "dpmsolver" +pipe.scheduler.config.solver_order = 3 +words = ["golden retriever"] # class_ids [207] +class_ids = pipe.get_label_ids(words) +class_ids = [206,207] +import ipdb; ipdb.set_trace() +# import ipdb; ipdb.set_trace() +timesteps_list = [999, 899, 799, 699, 599, 499, 399, 299, 199, 99] +order_list = [1, 2, 3, 1, 1, 2, 2, 2, 2, 1] +# generate image +set_seed(42) +generator = paddle.Generator().manual_seed(0) +image = pipe(class_labels=class_ids, num_inference_steps=10, generator=generator, timesteps_list = timesteps_list, order_list = order_list).images[0] +import ipdb; ipdb.set_trace() +image.save("result_DiT_golden_retriever_dpm_10_2.png") diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/infer_with_result.py b/ppdiffusers/examples/class_conditional_image_generation/DiT/infer_with_result.py new file mode 100644 index 000000000..c930a9cc0 --- /dev/null +++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/infer_with_result.py @@ -0,0 +1,40 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddlenlp.trainer import set_seed + +from ppdiffusers import DDIMScheduler, DiTPipeline, DPMSolverMultistepScheduler + +dtype = paddle.float32 +pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", paddle_dtype=dtype) +# import ipdb; ipdb.set_trace() +# use DDIMScheduler for inference +# pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) +pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) +pipe.scheduler.config.algorithm_type = "dpmsolver" +pipe.scheduler.config.solver_order = 3 +words = ["golden retriever"] # class_ids [207] +class_ids = pipe.get_label_ids(words) +search_result = paddle.load("/share/chenqian-local/PaddleMIX/ppdiffusers/examples/class_conditional_image_generation/DiT/search_result_debug/0/0.pdparams") + +timesteps_list = search_result[0]["timesteps"] +order_list = search_result[0]["orders"] +# import ipdb; ipdb.set_trace() +# generate image +set_seed(42) +generator = paddle.Generator().manual_seed(0) +image = pipe(class_labels=class_ids, num_inference_steps=10, generator=generator, timesteps_list = timesteps_list, order_list = order_list).images[0] +# import ipdb; ipdb.set_trace() +image.save("search_result_DiT_golden_retriever_dpm_10.png") diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/search.py b/ppdiffusers/examples/class_conditional_image_generation/DiT/search.py new file mode 100644 index 000000000..4356f5db5 --- /dev/null +++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/search.py @@ -0,0 +1,168 @@ +import paddle +from paddlenlp.trainer import set_seed +import yaml +import argparse + +from ppdiffusers import DDIMScheduler, DiTPipeline, DPMSolverMultistepScheduler +import os +import shutil +import random +from search_utils import * + +def load_prompts(prompt_path): + with open(prompt_path, "r") as f: + prompts = [line.strip() for line in f.readlines()] + return prompts + +def eval_coeff(coeff, cfg, generator ,prompts, pipe, save_dir, fixed_noise): + print(f"Begin to evaluate {coeff}") + sample_idx = 0 + + # Clean and create save directory + if os.path.exists(save_dir): + shutil.rmtree(save_dir) + os.makedirs(save_dir, exist_ok=True) + + # Generate samples + for i in range(0, len(prompts), cfg.batch_size): + # import ipdb; ipdb.set_trace() + batch_prompts = prompts[i : i + cfg.batch_size] + + # Sample using DiT with custom orders and timesteps + # samples = scheduler.sample( + # model, + # text_encoder, + # z_size=(vae.out_channels, *latent_size), + # prompts=batch_prompts, + # device=device, + # additional_args=model_args, + # orders=coeff["orders"], + # timesteps=coeff["timesteps"], + # input_noise=fixed_noise[sample_idx:sample_idx+cfg.batch_size], + # ) + # import ipdb; ipdb.set_trace() + samples = pipe(class_labels=batch_prompts, num_inference_steps=10, generator=generator, timesteps_list = coeff["timesteps"], order_list = coeff["orders"], fixed_noise = fixed_noise[sample_idx:sample_idx+cfg.batch_size]).images + # import ipdb; ipdb.set_trace() + + # Decode VAE latents to images + + for idx, sample in enumerate(samples): + print(f"Prompt: {batch_prompts[idx]}") + save_path = os.path.join(save_dir, f"sample_{sample_idx}.png") + sample.save(save_path) + sample_idx += 1 + from paddle_fid.fid_score import calculate_fid_given_paths + fid = calculate_fid_given_paths( + [save_dir,"/share/public-nfs/chenqian/var/dataset/imagenet256/VIRTUAL_imagenet256_labeled.npz"], + batch_size=256, + dims=2048, + num_workers=8, + ) + + return -fid + + + +def main(args): + dtype = paddle.float32 + cfg = args + # import ipdb; ipdb.set_trace() + + fixed_noise = paddle.load("/share/chenqian-local/PaddleMIX/ppdiffusers/examples/class_conditional_image_generation/dit_fixed_noise_B5000.pdparams") + + fixed_noise = fixed_noise["fixed_noise"] + prompts = load_prompts(cfg.prompt_path) + prompts = list(map(int, prompts)) + + pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", paddle_dtype=dtype) + pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) + pipe.scheduler.config.algorithm_type = "dpmsolver" + pipe.scheduler.config.solver_order = 3 + generator = paddle.Generator().manual_seed(0) + # set data save path + os.makedirs(os.path.join(cfg.data_path, cfg.split), exist_ok=True) + os.makedirs(os.path.join(cfg.data_path, "baselines"), exist_ok=True) + + # search + with open(cfg.search_config, "r") as f: + search_cfg = yaml.safe_load(f) + baseline_done = False + data_num = len([d for d in os.listdir(os.path.join(cfg.data_path, cfg.split)) if ".pdparams" in str(d)]) + image_save_dir = os.path.join(cfg.data_path, cfg.split, "images") + + # Main search loop + while(1): + print(f"Random check: {random.uniform(0, 1)}") + + # Baseline evaluation phase + if not baseline_done: + baseline, data_path = get_baseline(search_cfg, cfg.budget, cfg.data_path) + if baseline == -1: + baseline_done = True + print("All baselines evaluated, starting evolutionary search...") + else: + score = eval_coeff(baseline, cfg, generator, prompts, pipe, image_save_dir, fixed_noise) + paddle.save([baseline, score], data_path) + + print(f"Save baseline {[baseline, score]} to {data_path}") + + # delete occ file + str_data_path = str(data_path) + assert ".pdparams" in str_data_path + occ_path = str(data_path).replace(".pdparams", ".occ") + if os.path.exists(occ_path): + os.remove(occ_path) + + continue + + # Evolutionary search phase + population = get_population(cfg.data_path, search_cfg) + + # Decide between crossover and mutation + if random.uniform(0, 1) < search_cfg["crossover"]["prob"]: + # Crossover operation + parents_1, parents_2 = select_parents(population, search_cfg, num=2) + print(f"Choose {parents_1} and {parents_2} as the crossover parents") + new_coeff = crossover(parents_1[0], parents_2[0], search_cfg) + else: + new_coeff = None + + if new_coeff is not None: + parent = new_coeff + else: + # Mutation operation + parent = select_parents(population, search_cfg, num=1)[0] + print(f"Choose {parent} as the mutation parent") + + new_coeff = mutate(parent, search_cfg) + + # Evaluate new coefficient + score = eval_coeff(new_coeff, cfg, generator, prompts, pipe, image_save_dir, fixed_noise) + + # Save result + result_path = os.path.join(cfg.data_path, cfg.split, f"{data_num}.pdparams") + paddle.save([new_coeff, score], result_path) + print(f"Save {[new_coeff, score]} to {result_path}") + data_num += 1 + + # Update the search config (in case it was modified externally) + try: + with open(cfg.search_config, "r") as f: + search_cfg = yaml.safe_load(f) + except Exception as e: + print(f"Warning: Could not reload search config: {e}") + +if __name__ == "__main__": + + + parser = argparse.ArgumentParser() + parser.add_argument("--search_config", type=str, default="/share/chenqian-local/PaddleMIX/ppdiffusers/examples/class_conditional_image_generation/DiT/evo_search.yml") + parser.add_argument("--budget", type=int, default=10) + parser.add_argument("--data_path", type=str, default="/share/chenqian-local/PaddleMIX/ppdiffusers/examples/class_conditional_image_generation/DiT/search_result_debug") + parser.add_argument("--split", type=str, default="0") + parser.add_argument("--prompt_path", type=str, default="/share/chenqian-local/PaddleMIX/ppdiffusers/examples/class_conditional_image_generation/DiT/dit_5k.txt") + parser.add_argument("--batch_size", type=int, default=40) + args = parser.parse_args() + + # import ipdb; ipdb.set_trace() + main(args) diff --git a/ppdiffusers/examples/class_conditional_image_generation/DiT/search_utils.py b/ppdiffusers/examples/class_conditional_image_generation/DiT/search_utils.py new file mode 100644 index 000000000..775914bfa --- /dev/null +++ b/ppdiffusers/examples/class_conditional_image_generation/DiT/search_utils.py @@ -0,0 +1,195 @@ +from typing import List +import paddle +import numpy as np +import random +import pathlib +import os +import copy + +# from opensora.schedulers.dpms.dpm_solver import NoiseScheduleVP, get_named_beta_schedule + +# betas = torch.tensor(get_named_beta_schedule("linear", 1000)) +# noise_schedule = NoiseScheduleVP(schedule="discrete", betas=betas) + +# def get_time_steps(skip_type, t_T, t_0, N, device): +# """Compute the intermediate time steps for sampling. + +# Args: +# skip_type: A `str`. The type for the spacing of the time steps. We support three types: +# - 'logSNR': uniform logSNR for the time steps. +# - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) +# - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) +# t_T: A `float`. The starting time of the sampling (default is T). +# t_0: A `float`. The ending time of the sampling (default is epsilon). +# N: A `int`. The total number of the spacing of the time steps. +# device: A torch device. +# Returns: +# A pytorch tensor of the time steps, with the shape (N + 1,). +# """ +# if skip_type == "logSNR": +# lambda_T = noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) +# lambda_0 = noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) +# logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) +# return noise_schedule.inverse_lambda(logSNR_steps) +# elif skip_type == "time_uniform": +# return torch.linspace(t_T, t_0, N + 1).to(device) +# elif skip_type == "time_quadratic": +# t_order = 2 +# return torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device) +# else: +# raise ValueError( +# f"Unsupported skip_type {skip_type}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'" +# ) + +def get_time_steps(num_inference_steps, timestep_spacing): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + # Clipping the minimum of all lambda(t) for numerical stability. + # This is critical for cosine (squaredcos_cap_v2) noise schedule. + c = paddle.to_tensor(float("-inf"), dtype=paddle.float32) + lambda_t = paddle.load("/share/chenqian-local/PaddleMIX/ppdiffusers/examples/class_conditional_image_generation/DiT/lambda_t.pdparams") + t = paddle.flip(lambda_t, [0]) + clipped_idx = paddle.searchsorted(t, c) + if paddle.isinf(c): + clipped_idx = paddle.to_tensor(0) + # clipped_idx = paddle.searchsorted(paddle.flip(self.lambda_t, [0]), self.config.lambda_min_clipped) + last_timestep = ((1000 - clipped_idx).numpy()).item() + + # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 + + + if timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64) + ) + elif timestep_spacing == "leading": + step_ratio = last_timestep // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) + timesteps += 0 + elif timestep_spacing == "trailing": + step_ratio = 1000 / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + "is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) + return timesteps.tolist() + + + +def get_population(base_path, search_cfg): + base_path = pathlib.Path(base_path) + data_paths = [file for file in base_path.glob('**/*.{}'.format("pdparams"))] + population = [] + for path in data_paths: + population.append(paddle.load(str(path))) + print(f"Population size: {len(population)}") + population = sorted(population, key=lambda x: x[1] * search_cfg["metric"].get("indicator", 1)) + + return population + +def select_parents(population, search_cfg, num=1): + # import ipdb;ipdb.set_trace() + if random.uniform(0, 1) < search_cfg["parents"]["rank_prob"]: + mode = "rank" + else: + mode = "absolute" + + if mode == "rank": + all_candidate_parents = population[:search_cfg["parents"]["rank_bar"]] + elif mode == "absolute": + all_candidate_parents = [] + for p in population[1:]: + if abs(p[1] - population[0][1]) < search_cfg["parents"]["absolute_bar"]: + all_candidate_parents.append(p) + else: + break + + selected_parents = random.sample(all_candidate_parents, num) + if num == 1: + return selected_parents[0] + else: + return selected_parents + +def get_baseline(cfg, budget, data_path): + for timestep_type in cfg["baseline"]["timestep_type"]: + for order in cfg["baseline"]["orders"]: + prefix = f"{budget}step_{timestep_type}_order{order}" + if os.path.exists(os.path.join(data_path, "baselines", f"{prefix}.occ")) or \ + os.path.exists(os.path.join(data_path, "baselines", f"{prefix}.pdparams")): + continue + paddle.save("occ", os.path.join(data_path, "baselines", f"{prefix}.occ")) + + # get order list + order_list = [order] * budget + for i in range(order): + order_list[i] = min(i+1, order) + # import ipdb; ipdb.set_trace() + # get timesteps + + timesteps = get_time_steps(num_inference_steps = budget, timestep_spacing=timestep_type) + + return {"orders": order_list, "timesteps": timesteps}, os.path.join(data_path, "baselines", f"{prefix}.pdparams") + + return -1, -1 + +def crossover(parents_1, parents_2, cfg): + # import ipdb; ipdb.set_trace() + new_order_list = [] + for i in range(len(parents_1["orders"])): + if random.uniform(0, 1) < cfg["crossover"]["better_prob"]: + new_order_list.append(parents_1["orders"][i]) + else: + new_order_list.append(parents_2["orders"][i]) + + new_timesteps = [] + for i in range(len(parents_1["timesteps"])): + if random.uniform(0, 1) < cfg["crossover"]["better_prob"]: + new_timesteps.append(parents_1["timesteps"][i]) + else: + new_timesteps.append(parents_2["timesteps"][i]) + + new_timesteps, _ = paddle.sort(paddle.to_tensor(new_timesteps), descending=True) + new_timesteps = new_timesteps.cpu().tolist() + + return {"orders": new_order_list, "timesteps": new_timesteps} + +def mutate(parent, cfg): + # import ipdb; ipdb.set_trace() + new_order_list = copy.deepcopy(parent["orders"]) + new_timesteps = copy.deepcopy(parent["timesteps"]) + + + for i in range(len(new_order_list)): + a = random.uniform(0, 1) + # import ipdb; ipdb.set_trace() + print(a) + if a < cfg["mutate"]["order"]["prob"]: + order_dist = cfg["mutate"]["order"]["dist"] + new_order_list[i] = random.choices(list(order_dist.keys()), weights=list(order_dist.values()))[0] + if i != 0: + if random.uniform(0, 1) < cfg["mutate"]["timestep"]["prob"]: + new_timesteps[i] = new_timesteps[i] + int(round(random.gauss(0.0, 1.0) * cfg["mutate"]["timestep"]["scale"])) + min_time_step = 0 + if new_timesteps[i] < min_time_step: + new_timesteps[i] = round(random.uniform(min_time_step, new_timesteps[i-1])) + # new_timesteps[i] = max(1.0 / noise_schedule.total_N + 5e-4, new_timesteps[i]) + new_order_list[i] = min(i+1, new_order_list[i]) + + new_timesteps = paddle.sort(paddle.to_tensor(new_timesteps), descending=True) + new_timesteps = new_timesteps.cpu().tolist() + + return {"orders": new_order_list, "timesteps": new_timesteps} + \ No newline at end of file diff --git a/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py b/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py index 7e0e09670..b71fcaf3a 100644 --- a/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py +++ b/ppdiffusers/ppdiffusers/pipelines/dit/pipeline_dit.py @@ -106,6 +106,9 @@ def __call__( num_inference_steps: int = 50, output_type: Optional[str] = "pil", return_dict: bool = True, + order_list: Optional[List[int]] = None, + timesteps_list: Optional[List[int]] = None, + fixed_noise: Optional[paddle.Tensor] = None, ) -> Union[ImagePipelineOutput, Tuple]: r""" The call function to the pipeline for generation. @@ -158,12 +161,15 @@ def __call__( batch_size = len(class_labels) latent_size = self.transformer.config.sample_size latent_channels = self.transformer.config.in_channels - - latents = randn_tensor( - shape=(batch_size, latent_channels, latent_size, latent_size), - generator=generator, - dtype=self.transformer.dtype, - ) + # import ipdb; ipdb.set_trace() + if fixed_noise is not None: + latents = fixed_noise + else: + latents = randn_tensor( + shape=(batch_size, latent_channels, latent_size, latent_size), + generator=generator, + dtype=self.transformer.dtype, + ) latent_model_input = paddle.concat([latents] * 2) if guidance_scale > 1 else latents class_labels = paddle.to_tensor(class_labels).reshape( @@ -175,7 +181,12 @@ def __call__( class_labels_input = paddle.concat([class_labels, class_null], 0) if guidance_scale > 1 else class_labels # set step values - self.scheduler.set_timesteps(num_inference_steps) + if timesteps_list is not None: + self.scheduler.set_timesteps(num_inference_steps, timesteps_list=timesteps_list) + else: + self.scheduler.set_timesteps(num_inference_steps) + # import ipdb; ipdb.set_trace() + orderindex = 0 for t in self.progress_bar(self.scheduler.timesteps): if guidance_scale > 1: half = latent_model_input[: len(latent_model_input) // 2] @@ -225,7 +236,11 @@ def __call__( model_output = noise_pred # compute previous image: x_t -> x_t-1 - latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample + if order_list is not None: + latent_model_input = self.scheduler.step(model_output, t, latent_model_input, order=order_list[orderindex]).prev_sample + orderindex = orderindex + 1 + else: + latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample if guidance_scale > 1: latents, _ = latent_model_input.chunk(2, axis=0) diff --git a/ppdiffusers/ppdiffusers/schedulers/scheduling_dpmsolver_multistep.py b/ppdiffusers/ppdiffusers/schedulers/scheduling_dpmsolver_multistep.py index 1672daa5c..e99aee165 100644 --- a/ppdiffusers/ppdiffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/ppdiffusers/ppdiffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -223,7 +223,7 @@ def step_index(self): """ return self._step_index - def set_timesteps(self, num_inference_steps: int = None): + def set_timesteps(self, num_inference_steps: int = None, timesteps_list: List[int] = None): """ Sets the discrete timesteps used for the diffusion chain (to be run before inference). @@ -235,6 +235,7 @@ def set_timesteps(self, num_inference_steps: int = None): """ # Clipping the minimum of all lambda(t) for numerical stability. # This is critical for cosine (squaredcos_cap_v2) noise schedule. + # import ipdb; ipdb.set_trace() c = paddle.to_tensor(self.config.lambda_min_clipped) t = paddle.flip(self.lambda_t, [0]) clipped_idx = paddle.searchsorted(t, c) @@ -244,26 +245,29 @@ def set_timesteps(self, num_inference_steps: int = None): last_timestep = ((self.config.num_train_timesteps - clipped_idx).numpy()).item() # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891 - if self.config.timestep_spacing == "linspace": - timesteps = ( - np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64) - ) - elif self.config.timestep_spacing == "leading": - step_ratio = last_timestep // (num_inference_steps + 1) - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) - timesteps += self.config.steps_offset - elif self.config.timestep_spacing == "trailing": - step_ratio = self.config.num_train_timesteps / num_inference_steps - # creates integer timesteps by multiplying by ratio - # casting to int to avoid issues when num_inference_step is power of 3 - timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64) - timesteps -= 1 + if timesteps_list is not None: + timesteps = np.array(timesteps_list, dtype=np.int64) else: - raise ValueError( - f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." - ) + if self.config.timestep_spacing == "linspace": + timesteps = ( + np.linspace(0, last_timestep - 1, num_inference_steps + 1).round()[::-1][:-1].copy().astype(np.int64) + ) + elif self.config.timestep_spacing == "leading": + step_ratio = last_timestep // (num_inference_steps + 1) + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64) + timesteps += self.config.steps_offset + elif self.config.timestep_spacing == "trailing": + step_ratio = self.config.num_train_timesteps / num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.arange(last_timestep, 0, -step_ratio).round().copy().astype(np.int64) + timesteps -= 1 + else: + raise ValueError( + f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'." + ) sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) log_sigmas = np.log(sigmas) @@ -755,6 +759,7 @@ def multistep_dpm_solver_third_order_update( lambda_s0 = paddle.log(alpha_s0) - paddle.log(sigma_s0) lambda_s1 = paddle.log(alpha_s1) - paddle.log(sigma_s1) lambda_s2 = paddle.log(alpha_s2) - paddle.log(sigma_s2) + # import ipdb; ipdb.set_trace() m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] @@ -806,6 +811,7 @@ def step( sample: paddle.Tensor, generator=None, return_dict: bool = True, + order: int = None, ) -> Union[SchedulerOutput, Tuple]: """ Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with @@ -859,15 +865,23 @@ def step( else: noise = None - if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: - prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) - elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: - prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) + if order is not None: + if order == 1: + prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) + elif order == 2: + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) + elif order == 3: + prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample) else: - prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample) + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) + else: + prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample) - if self.lower_order_nums < self.config.solver_order: - self.lower_order_nums += 1 + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 # upon completion increase step index by one self._step_index += 1