Skip to content

Commit b640c39

Browse files
authored
Update accel_sdxl_gen_img.py
1 parent d69619c commit b640c39

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

accel_sdxl_gen_img.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2927,7 +2927,13 @@ def scale_and_round(x):
29272927
elif len(split_into_batches) == 1 :
29282928
sublist.extend(split_into_batches.pop(-1))
29292929
split_into_batches = []
2930-
# sublist = sorted(sublist, key=lambda x: x.global_count)
2930+
sublist = sorted(sublist, key=lambda x: x.global_count)
2931+
resorted_list = []
2932+
for i in range(distributed_state.num_processes):
2933+
resorted_list.append(sublist[i :: distributed_state.num_processes])
2934+
sublist = []
2935+
for list_of_prompts in resorted_list:
2936+
sublist.extend(list_of_prompts)
29312937

29322938
n, m = divmod(len(sublist), distributed_state.num_processes)
29332939
split_into_batches.extend([sublist[i*n+min(i,m):(i+1)*n+min(i+1,m)] for i in range(distributed_state.num_processes)])

0 commit comments

Comments
 (0)