Skip to content

Commit 8c13362

Browse files
authored
Merge pull request #884 from modelscope/dev2-dzj
Unified Dataset & Split Training
2 parents c13fd7e + 958ebf1 commit 8c13362

File tree

8 files changed

+598
-18
lines changed

8 files changed

+598
-18
lines changed

diffsynth/pipelines/qwen_image.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,12 @@ def _enable_fp8_lora_training(self, dtype):
174174
computation_dtype=self.torch_dtype,
175175
computation_device="cuda",
176176
)
177-
enable_vram_management(self.text_encoder, module_map=module_map, module_config=model_config)
178-
enable_vram_management(self.dit, module_map=module_map, module_config=model_config)
179-
enable_vram_management(self.vae, module_map=module_map, module_config=model_config)
177+
if self.text_encoder is not None:
178+
enable_vram_management(self.text_encoder, module_map=module_map, module_config=model_config)
179+
if self.dit is not None:
180+
enable_vram_management(self.dit, module_map=module_map, module_config=model_config)
181+
if self.vae is not None:
182+
enable_vram_management(self.vae, module_map=module_map, module_config=model_config)
180183

181184

182185
def enable_vram_management(self, num_persistent_param_in_dit=None, vram_limit=None, vram_buffer=0.5, enable_dit_fp8_computation=False):
Lines changed: 334 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
import torch, torchvision, imageio, os, json, pandas
2+
import imageio.v3 as iio
3+
from PIL import Image
4+
5+
6+
7+
class DataProcessingPipeline:
    """An ordered chain of callables applied one after another to a piece of data.

    Pipelines are composed with ``>>``: the right-hand side may be a single
    ``DataProcessingOperator`` or another pipeline.
    """

    def __init__(self, operators=None):
        # A fresh list is created per instance when no operators are supplied.
        if operators is None:
            operators = []
        self.operators: list[DataProcessingOperator] = operators

    def __call__(self, data):
        """Feed ``data`` through every operator in order and return the final result."""
        result = data
        for step in self.operators:
            result = step(result)
        return result

    def __rshift__(self, pipe):
        """Return a new pipeline made of this pipeline's operators followed by ``pipe``'s."""
        # A bare operator contributes itself; anything else must expose ``.operators``.
        tail = [pipe] if isinstance(pipe, DataProcessingOperator) else pipe.operators
        return DataProcessingPipeline(self.operators + tail)
20+
21+
22+
23+
class DataProcessingOperator:
    """Base class for a single processing step.

    Subclasses override ``__call__``; operators are chained with ``>>`` to build
    a ``DataProcessingPipeline``.
    """

    def __call__(self, data):
        raise NotImplementedError("DataProcessingOperator cannot be called directly.")

    def __rshift__(self, pipe):
        """Return a pipeline that runs ``self`` first and then ``pipe``."""
        if isinstance(pipe, DataProcessingOperator):
            following = DataProcessingPipeline([pipe])
        else:
            following = pipe
        leading = DataProcessingPipeline([self])
        return leading.__rshift__(following)
31+
32+
33+
34+
class DataProcessingOperatorRaw(DataProcessingOperator):
    """Identity operator: hands the input back unchanged."""

    def __call__(self, data):
        # Deliberately a no-op.
        return data
37+
38+
39+
40+
class ToInt(DataProcessingOperator):
    """Operator that converts its input with the built-in ``int``."""

    def __call__(self, data):
        converted = int(data)
        return converted
43+
44+
45+
46+
class ToFloat(DataProcessingOperator):
    """Operator that converts its input with the built-in ``float``."""

    def __call__(self, data):
        converted = float(data)
        return converted
49+
50+
51+
52+
class ToStr(DataProcessingOperator):
    """Operator that converts its input to ``str``.

    Args:
        none_value: Value substituted for ``None`` before the conversion
            (empty string by default).
    """

    def __init__(self, none_value=""):
        self.none_value = none_value

    def __call__(self, data):
        value = self.none_value if data is None else data
        return str(value)
59+
60+
61+
62+
class LoadImage(DataProcessingOperator):
    """Open an image file with PIL.

    Args:
        convert_RGB: When true, the opened image is converted to ``"RGB"`` mode.
    """

    def __init__(self, convert_RGB=True):
        self.convert_RGB = convert_RGB

    def __call__(self, data: str):
        opened = Image.open(data)
        return opened.convert("RGB") if self.convert_RGB else opened
70+
71+
72+
73+
class ImageCropAndResize(DataProcessingOperator):
    """Resize an image so it covers a target size, then center-crop it to that size.

    Args:
        height, width: Fixed output size. If either is ``None`` the size is
            derived from each input image instead.
        max_pixels: Pixel budget used when the size is derived; larger images
            are scaled down to fit it.
        height_division_factor, width_division_factor: Derived sizes are
            floored to multiples of these factors.
    """

    def __init__(self, height, width, max_pixels, height_division_factor, width_division_factor):
        self.height, self.width = height, width
        self.max_pixels = max_pixels
        self.height_division_factor = height_division_factor
        self.width_division_factor = width_division_factor

    def crop_and_resize(self, image, target_height, target_width):
        """Return ``image`` resized (bilinear) and center-cropped to the target size."""
        src_width, src_height = image.size
        # Use the larger ratio so the resized image covers the target on both axes.
        ratio = max(target_width / src_width, target_height / src_height)
        resized_size = (round(src_height * ratio), round(src_width * ratio))
        resized = torchvision.transforms.functional.resize(
            image,
            resized_size,
            interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
        )
        return torchvision.transforms.functional.center_crop(resized, (target_height, target_width))

    def get_height_width(self, image):
        """Return the ``(height, width)`` that ``image`` should be cropped/resized to."""
        if self.height is not None and self.width is not None:
            return self.height, self.width
        # Dynamic size: start from the image's own size, shrink to the pixel budget,
        # then floor each side to a multiple of its division factor.
        width, height = image.size
        if width * height > self.max_pixels:
            shrink = (width * height / self.max_pixels) ** 0.5
            height, width = int(height / shrink), int(width / shrink)
        height = height // self.height_division_factor * self.height_division_factor
        width = width // self.width_division_factor * self.width_division_factor
        return height, width

    def __call__(self, data: Image.Image):
        target_height, target_width = self.get_height_width(data)
        return self.crop_and_resize(data, target_height, target_width)
108+
109+
110+
111+
class ToList(DataProcessingOperator):
    """Operator that wraps its input in a one-element list."""

    def __call__(self, data):
        wrapped = [data]
        return wrapped
114+
115+
116+
117+
class LoadVideo(DataProcessingOperator):
    """Decode the leading frames of a video file into a list of processed frames.

    Args:
        num_frames: Maximum number of frames to load.
        time_division_factor, time_division_remainder: The number of frames
            actually loaded is reduced until
            ``n % time_division_factor == time_division_remainder`` (or ``n <= 1``).
        frame_processor: Callable applied to every decoded frame (a PIL image).
    """

    def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
        self.num_frames = num_frames
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        # The frame processor is built into the video loader for efficiency:
        # each frame is processed right after it is decoded.
        self.frame_processor = frame_processor

    def get_num_frames(self, reader):
        """Return how many frames to read from ``reader``."""
        # Query the frame count once; NOTE(review): some imageio backends may
        # need to scan the file to answer this — confirm for the plugin in use.
        total_frames = int(reader.count_frames())
        num_frames = min(self.num_frames, total_frames)
        while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
            num_frames -= 1
        return num_frames

    def __call__(self, data: str):
        """Load frames from the video at path ``data`` and return them as a list."""
        reader = imageio.get_reader(data)
        try:
            num_frames = self.get_num_frames(reader)
            return [
                self.frame_processor(Image.fromarray(reader.get_data(frame_id)))
                for frame_id in range(num_frames)
            ]
        finally:
            # Always release the reader, even if decoding or processing fails.
            reader.close()
144+
145+
146+
147+
class SequencialProcess(DataProcessingOperator):
    """Apply one operator to every element of an iterable and collect the results.

    Args:
        operator: Callable applied to each element (identity by default).
    """

    def __init__(self, operator=lambda x: x):
        self.operator = operator

    def __call__(self, data):
        processed = []
        for element in data:
            processed.append(self.operator(element))
        return processed
153+
154+
155+
156+
class LoadGIF(DataProcessingOperator):
    """Load the leading frames of a GIF file into a list of processed frames.

    Args:
        num_frames: Maximum number of frames to load.
        time_division_factor, time_division_remainder: The number of frames
            actually loaded is reduced until
            ``n % time_division_factor == time_division_remainder`` (or ``n <= 1``).
        frame_processor: Callable applied to every frame (a PIL image).
    """

    def __init__(self, num_frames=81, time_division_factor=4, time_division_remainder=1, frame_processor=lambda x: x):
        self.num_frames = num_frames
        self.time_division_factor = time_division_factor
        self.time_division_remainder = time_division_remainder
        # The frame processor is built into the loader for efficiency.
        self.frame_processor = frame_processor

    def _clamp_num_frames(self, total_frames):
        """Return the frame count to load given that ``total_frames`` are available."""
        num_frames = min(self.num_frames, total_frames)
        while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder:
            num_frames -= 1
        return num_frames

    def get_num_frames(self, path):
        """Return how many frames would be loaded from the GIF at ``path``."""
        return self._clamp_num_frames(len(iio.imread(path, mode="RGB")))

    def __call__(self, data: str):
        """Load frames from the GIF at path ``data`` and return them as a list."""
        # Read the file once; the frame count is derived from the same array
        # instead of decoding the GIF a second time.
        # NOTE(review): assumes imread returns frames along the first axis — confirm.
        images = iio.imread(data, mode="RGB")
        num_frames = self._clamp_num_frames(len(images))
        # As before, at least one frame is emitted whenever the GIF is non-empty.
        return [
            self.frame_processor(Image.fromarray(img))
            for img in images[:max(num_frames, 1)]
        ]
184+
185+
186+
187+
class RouteByExtensionName(DataProcessingOperator):
    """Dispatch a file path to an operator chosen by its extension.

    Args:
        operator_map: Iterable of ``(ext_names, operator)`` pairs, checked in
            order. ``ext_names`` is a container of lowercase extensions without
            the dot, or ``None`` to match every path.

    Raises:
        ValueError: If no entry matches the path.
    """

    def __init__(self, operator_map):
        self.operator_map = operator_map

    def __call__(self, data: str):
        # Text after the last "." (the whole string when there is no dot), lowercased.
        extension = data.rsplit(".", 1)[-1].lower()
        for ext_names, operator in self.operator_map:
            if ext_names is not None and extension not in ext_names:
                continue
            return operator(data)
        raise ValueError(f"Unsupported file: {data}")
197+
198+
199+
200+
class RouteByType(DataProcessingOperator):
    """Dispatch a value to an operator chosen by its Python type.

    Args:
        operator_map: Iterable of ``(dtype, operator)`` pairs, checked in order.
            ``dtype`` is anything accepted by ``isinstance``, or ``None`` to
            match every value.

    Raises:
        ValueError: If no entry matches the value.
    """

    def __init__(self, operator_map):
        self.operator_map = operator_map

    def __call__(self, data):
        for dtype, operator in self.operator_map:
            if dtype is not None and not isinstance(data, dtype):
                continue
            return operator(data)
        raise ValueError(f"Unsupported data: {data}")
209+
210+
211+
212+
class LoadTorchPickle(DataProcessingOperator):
    """Load an object saved with ``torch.save``.

    Args:
        map_location: Forwarded to ``torch.load`` (``"cpu"`` by default).
    """

    def __init__(self, map_location="cpu"):
        self.map_location = map_location

    def __call__(self, data):
        # NOTE(review): weights_only=False performs full unpickling, which can
        # execute arbitrary code — only use this on trusted files.
        load_kwargs = {"map_location": self.map_location, "weights_only": False}
        return torch.load(data, **load_kwargs)
218+
219+
220+
221+
class ToAbsolutePath(DataProcessingOperator):
    """Join a path onto a fixed base path with ``os.path.join``.

    Args:
        base_path: Prefix directory (empty string by default).
    """

    def __init__(self, base_path=""):
        self.base_path = base_path

    def __call__(self, data):
        joined = os.path.join(self.base_path, data)
        return joined
227+
228+
229+
230+
class UnifiedDataset(torch.utils.data.Dataset):
    """Dataset serving either metadata records or cached ``.pth`` files.

    With ``metadata_path`` set, records are read from a ``.json`` or ``.jsonl``
    file (any other extension is parsed as CSV); on access, the fields listed in
    ``data_file_keys`` are passed through an operator. With
    ``metadata_path=None``, ``base_path`` is searched recursively for ``.pth``
    files, which are loaded with ``LoadTorchPickle`` on access.

    Args:
        base_path: Root directory searched for cached files when no metadata
            file is given.
        metadata_path: Path of the metadata file, or ``None`` for cache mode.
        repeat: Multiplier applied to the reported dataset length; indices wrap
            around the underlying records.
        data_file_keys: Record keys whose values are processed by an operator.
        main_data_operator: Operator applied to those keys by default.
        special_operator_map: Optional ``{key: operator}`` overrides that take
            precedence over ``main_data_operator`` for the given keys.
    """

    def __init__(
        self,
        base_path=None, metadata_path=None,
        repeat=1,
        data_file_keys=tuple(),
        main_data_operator=lambda x: x,
        special_operator_map=None,
    ):
        self.base_path = base_path
        self.metadata_path = metadata_path
        self.repeat = repeat
        self.data_file_keys = data_file_keys
        self.main_data_operator = main_data_operator
        self.cached_data_operator = LoadTorchPickle()
        self.special_operator_map = {} if special_operator_map is None else special_operator_map
        self.data = []
        self.cached_data = []
        # Without a metadata file the dataset serves cached files only.
        self.load_from_cache = metadata_path is None
        self.load_metadata(metadata_path)

    @staticmethod
    def default_image_operator(
        base_path="",
        max_pixels=1920*1080, height=None, width=None,
        height_division_factor=16, width_division_factor=16,
    ):
        """Build an operator that loads one image path (``str``) or a ``list`` of them."""
        load_one = (
            ToAbsolutePath(base_path)
            >> LoadImage()
            >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor)
        )
        return RouteByType(operator_map=[
            (str, load_one),
            (list, SequencialProcess(load_one)),
        ])

    @staticmethod
    def default_video_operator(
        base_path="",
        max_pixels=1920*1080, height=None, width=None,
        height_division_factor=16, width_division_factor=16,
        num_frames=81, time_division_factor=4, time_division_remainder=1,
    ):
        """Build an operator that loads an image, GIF or video path as a list of frames."""
        return RouteByType(operator_map=[
            (str, ToAbsolutePath(base_path) >> RouteByExtensionName(operator_map=[
                (("jpg", "jpeg", "png", "webp"), LoadImage() >> ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor) >> ToList()),
                # LoadGIF returns a list of frames, so the crop/resize must run per
                # frame via frame_processor rather than on the list as a whole.
                (("gif",), LoadGIF(
                    num_frames, time_division_factor, time_division_remainder,
                    frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
                )),
                (("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), LoadVideo(
                    num_frames, time_division_factor, time_division_remainder,
                    frame_processor=ImageCropAndResize(height, width, max_pixels, height_division_factor, width_division_factor),
                )),
            ])),
        ])

    def search_for_cached_data_files(self, path):
        """Recursively append every ``.pth`` file under ``path`` to ``self.cached_data``."""
        # Sorted so that the sample order does not depend on os.listdir's ordering.
        for file_name in sorted(os.listdir(path)):
            subpath = os.path.join(path, file_name)
            if os.path.isdir(subpath):
                self.search_for_cached_data_files(subpath)
            elif subpath.endswith(".pth"):
                self.cached_data.append(subpath)

    def load_metadata(self, metadata_path):
        """Populate ``self.data`` from the metadata file, or ``self.cached_data`` in cache mode."""
        if metadata_path is None:
            print("No metadata_path. Searching for cached data files.")
            self.search_for_cached_data_files(self.base_path)
            print(f"{len(self.cached_data)} cached data files found.")
        elif metadata_path.endswith(".json"):
            with open(metadata_path, "r", encoding="utf-8") as f:
                self.data = json.load(f)
        elif metadata_path.endswith(".jsonl"):
            with open(metadata_path, "r", encoding="utf-8") as f:
                # Blank lines (e.g. a trailing newline) are skipped instead of
                # crashing json.loads.
                self.data = [json.loads(line) for line in f if line.strip()]
        else:
            # Any other extension is treated as CSV.
            metadata = pandas.read_csv(metadata_path)
            self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))]

    def __getitem__(self, data_id):
        """Return the processed sample at ``data_id`` (indices wrap around)."""
        if self.load_from_cache:
            cached_path = self.cached_data[data_id % len(self.cached_data)]
            return self.cached_data_operator(cached_path)
        # Shallow copy so the stored metadata record is not overwritten.
        data = self.data[data_id % len(self.data)].copy()
        for key in self.data_file_keys:
            if key not in data:
                continue
            if key in self.special_operator_map:
                operator = self.special_operator_map[key]
            else:
                operator = self.main_data_operator
            # The operator is applied to the value (a special operator was
            # previously stored in the record without being called).
            data[key] = operator(data[key])
        return data

    def __len__(self):
        """Return the number of underlying samples multiplied by ``repeat``."""
        source = self.cached_data if self.load_from_cache else self.data
        return len(source) * self.repeat

    def check_data_equal(self, data1, data2):
        """Debug helper: return whether two mappings have the same keys and values."""
        if len(data1) != len(data2):
            return False
        for k in data1:
            # A key missing from data2 means "not equal" rather than a KeyError.
            if k not in data2 or data1[k] != data2[k]:
                return False
        return True

0 commit comments

Comments
 (0)