Skip to content

Commit 0f0618f

Browse files
committed
refactor the denoiseestep using LoopSequential! also add a new file for denoise step
1 parent d89631f commit 0f0618f

File tree

3 files changed

+1607
-595
lines changed

3 files changed

+1607
-595
lines changed

src/diffusers/pipelines/modular_pipeline.py

Lines changed: 285 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,23 @@ def __init__(self, **kwargs):
184184
for key, value in kwargs.items():
185185
setattr(self, key, value)
186186

187+
def __getitem__(self, key: str):
188+
# allows block_state["foo"]
189+
return getattr(self, key, None)
190+
191+
def __setitem__(self, key: str, value: Any):
192+
# allows block_state["foo"] = "bar"
193+
setattr(self, key, value)
194+
195+
def as_dict(self):
196+
"""
197+
Convert BlockState to a dictionary.
198+
199+
Returns:
200+
Dict[str, Any]: Dictionary containing all attributes of the BlockState
201+
"""
202+
return {key: value for key, value in self.__dict__.items()}
203+
187204
def __repr__(self):
188205
def format_value(v):
189206
# Handle tensors directly
@@ -523,8 +540,12 @@ def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> Li
523540

524541
for block_name, inputs in named_input_lists:
525542
for input_param in inputs:
526-
if input_param.name in combined_dict:
527-
current_param = combined_dict[input_param.name]
543+
if input_param.name is None and input_param.kwargs_type is not None:
544+
input_name = "*_" + input_param.kwargs_type
545+
else:
546+
input_name = input_param.name
547+
if input_name in combined_dict:
548+
current_param = combined_dict[input_name]
528549
if (current_param.default is not None and
529550
input_param.default is not None and
530551
current_param.default != input_param.default):
@@ -557,7 +578,7 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) ->
557578

558579
for block_name, outputs in named_output_lists:
559580
for output_param in outputs:
560-
if output_param.name not in combined_dict:
581+
if (output_param.name not in combined_dict) or (combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None):
561582
combined_dict[output_param.name] = output_param
562583

563584
return list(combined_dict.values())
@@ -919,6 +940,9 @@ def required_intermediates_inputs(self) -> List[str]:
919940
# YiYi TODO: add test for this
920941
@property
921942
def inputs(self) -> List[Tuple[str, Any]]:
943+
return self.get_inputs()
944+
945+
def get_inputs(self):
922946
named_inputs = [(name, block.inputs) for name, block in self.blocks.items()]
923947
combined_inputs = combine_inputs(*named_inputs)
924948
# mark Required inputs only if that input is required any of the blocks
@@ -931,6 +955,9 @@ def inputs(self) -> List[Tuple[str, Any]]:
931955

932956
@property
933957
def intermediates_inputs(self) -> List[str]:
958+
return self.get_intermediates_inputs()
959+
960+
def get_intermediates_inputs(self):
934961
inputs = []
935962
outputs = set()
936963

@@ -1169,7 +1196,262 @@ def doc(self):
11691196
expected_configs=self.expected_configs
11701197
)
11711198

1199+
#YiYi TODO: __repr__
1200+
class LoopSequentialPipelineBlocks(ModularPipelineMixin):
1201+
"""
1202+
A class that combines multiple pipeline block classes into a For Loop. When called, it will call each block in sequence.
1203+
"""
1204+
1205+
model_name = None
1206+
block_classes = []
1207+
block_names = []
1208+
1209+
@property
1210+
def description(self) -> str:
1211+
"""Description of the block. Must be implemented by subclasses."""
1212+
raise NotImplementedError("description method must be implemented in subclasses")
1213+
1214+
@property
1215+
def loop_expected_components(self) -> List[ComponentSpec]:
1216+
return []
1217+
1218+
@property
1219+
def loop_expected_configs(self) -> List[ConfigSpec]:
1220+
return []
1221+
1222+
@property
1223+
def loop_inputs(self) -> List[InputParam]:
1224+
"""List of input parameters. Must be implemented by subclasses."""
1225+
return []
1226+
1227+
@property
1228+
def loop_intermediates_inputs(self) -> List[InputParam]:
1229+
"""List of intermediate input parameters. Must be implemented by subclasses."""
1230+
return []
1231+
1232+
@property
1233+
def loop_intermediates_outputs(self) -> List[OutputParam]:
1234+
"""List of intermediate output parameters. Must be implemented by subclasses."""
1235+
return []
1236+
1237+
1238+
@property
1239+
def loop_required_inputs(self) -> List[str]:
1240+
input_names = []
1241+
for input_param in self.loop_inputs:
1242+
if input_param.required:
1243+
input_names.append(input_param.name)
1244+
return input_names
1245+
1246+
@property
1247+
def loop_required_intermediates_inputs(self) -> List[str]:
1248+
input_names = []
1249+
for input_param in self.loop_intermediates_inputs:
1250+
if input_param.required:
1251+
input_names.append(input_param.name)
1252+
return input_names
1253+
1254+
# modified from SequentialPipelineBlocks to include loop_expected_components
1255+
@property
1256+
def expected_components(self):
1257+
expected_components = []
1258+
for block in self.blocks.values():
1259+
for component in block.expected_components:
1260+
if component not in expected_components:
1261+
expected_components.append(component)
1262+
for component in self.loop_expected_components:
1263+
if component not in expected_components:
1264+
expected_components.append(component)
1265+
return expected_components
1266+
1267+
# modified from SequentialPipelineBlocks to include loop_expected_configs
1268+
@property
1269+
def expected_configs(self):
1270+
expected_configs = []
1271+
for block in self.blocks.values():
1272+
for config in block.expected_configs:
1273+
if config not in expected_configs:
1274+
expected_configs.append(config)
1275+
for config in self.loop_expected_configs:
1276+
if config not in expected_configs:
1277+
expected_configs.append(config)
1278+
return expected_configs
1279+
1280+
# modified from SequentialPipelineBlocks to include loop_inputs
1281+
def get_inputs(self):
1282+
named_inputs = [(name, block.inputs) for name, block in self.blocks.items()]
1283+
named_inputs.append(("loop", self.loop_inputs))
1284+
combined_inputs = combine_inputs(*named_inputs)
1285+
# mark Required inputs only if that input is required any of the blocks
1286+
for input_param in combined_inputs:
1287+
if input_param.name in self.required_inputs:
1288+
input_param.required = True
1289+
else:
1290+
input_param.required = False
1291+
return combined_inputs
1292+
1293+
# Copied from SequentialPipelineBlocks
1294+
@property
1295+
def inputs(self):
1296+
return self.get_inputs()
1297+
1298+
1299+
# modified from SequentialPipelineBlocks to include loop_intermediates_inputs
1300+
@property
1301+
def intermediates_inputs(self):
1302+
intermediates = self.get_intermediates_inputs()
1303+
intermediate_names = [input.name for input in intermediates]
1304+
for loop_intermediate_input in self.loop_intermediates_inputs:
1305+
if loop_intermediate_input.name not in intermediate_names:
1306+
intermediates.append(loop_intermediate_input)
1307+
return intermediates
1308+
1309+
1310+
# Copied from SequentialPipelineBlocks
1311+
def get_intermediates_inputs(self):
1312+
inputs = []
1313+
outputs = set()
1314+
1315+
# Go through all blocks in order
1316+
for block in self.blocks.values():
1317+
# Add inputs that aren't in outputs yet
1318+
inputs.extend(input_name for input_name in block.intermediates_inputs if input_name.name not in outputs)
1319+
1320+
# Only add outputs if the block cannot be skipped
1321+
should_add_outputs = True
1322+
if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
1323+
should_add_outputs = False
1324+
1325+
if should_add_outputs:
1326+
# Add this block's outputs
1327+
block_intermediates_outputs = [out.name for out in block.intermediates_outputs]
1328+
outputs.update(block_intermediates_outputs)
1329+
return inputs
1330+
11721331

1332+
# modified from SequentialPipelineBlocks, if any additionan input required by the loop is required by the block
1333+
@property
1334+
def required_inputs(self) -> List[str]:
1335+
# Get the first block from the dictionary
1336+
first_block = next(iter(self.blocks.values()))
1337+
required_by_any = set(getattr(first_block, "required_inputs", set()))
1338+
1339+
required_by_loop = set(getattr(self, "loop_required_inputs", set()))
1340+
required_by_any.update(required_by_loop)
1341+
1342+
# Union with required inputs from all other blocks
1343+
for block in list(self.blocks.values())[1:]:
1344+
block_required = set(getattr(block, "required_inputs", set()))
1345+
required_by_any.update(block_required)
1346+
1347+
return list(required_by_any)
1348+
1349+
# modified from SequentialPipelineBlocks, if any additional intermediate input required by the loop is required by the block
1350+
@property
1351+
def required_intermediates_inputs(self) -> List[str]:
1352+
required_intermediates_inputs = []
1353+
for input_param in self.intermediates_inputs:
1354+
if input_param.required:
1355+
required_intermediates_inputs.append(input_param.name)
1356+
for input_param in self.loop_intermediates_inputs:
1357+
if input_param.required:
1358+
required_intermediates_inputs.append(input_param.name)
1359+
return required_intermediates_inputs
1360+
1361+
1362+
# YiYi TODO: this need to be thought about more
1363+
# modified from SequentialPipelineBlocks to include loop_intermediates_outputs
1364+
@property
1365+
def intermediates_outputs(self) -> List[str]:
1366+
named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()]
1367+
combined_outputs = combine_outputs(*named_outputs)
1368+
for output in self.loop_intermediates_outputs:
1369+
if output.name not in set([output.name for output in combined_outputs]):
1370+
combined_outputs.append(output)
1371+
return combined_outputs
1372+
1373+
# YiYi TODO: this need to be thought about more
1374+
# copied from SequentialPipelineBlocks
1375+
@property
1376+
def outputs(self) -> List[str]:
1377+
return next(reversed(self.blocks.values())).intermediates_outputs
1378+
1379+
1380+
def __init__(self):
1381+
blocks = OrderedDict()
1382+
for block_name, block_cls in zip(self.block_names, self.block_classes):
1383+
blocks[block_name] = block_cls()
1384+
self.blocks = blocks
1385+
1386+
def loop_step(self, components, state: PipelineState, **kwargs):
1387+
1388+
for block_name, block in self.blocks.items():
1389+
try:
1390+
components, state = block(components, state, **kwargs)
1391+
except Exception as e:
1392+
error_msg = (
1393+
f"\nError in block: ({block_name}, {block.__class__.__name__})\n"
1394+
f"Error details: {str(e)}\n"
1395+
f"Traceback:\n{traceback.format_exc()}"
1396+
)
1397+
logger.error(error_msg)
1398+
raise
1399+
return components, state
1400+
1401+
def __call__(self, components, state: PipelineState) -> PipelineState:
1402+
raise NotImplementedError("`__call__` method needs to be implemented by the subclass")
1403+
1404+
1405+
def get_block_state(self, state: PipelineState) -> dict:
1406+
"""Get all inputs and intermediates in one dictionary"""
1407+
data = {}
1408+
1409+
# Check inputs
1410+
for input_param in self.inputs:
1411+
if input_param.name:
1412+
value = state.get_input(input_param.name)
1413+
if input_param.required and value is None:
1414+
raise ValueError(f"Required input '{input_param.name}' is missing")
1415+
elif value is not None or (value is None and input_param.name not in data):
1416+
data[input_param.name] = value
1417+
elif input_param.kwargs_type:
1418+
# if kwargs_type is provided, get all inputs with matching kwargs_type
1419+
if input_param.kwargs_type not in data:
1420+
data[input_param.kwargs_type] = {}
1421+
inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type)
1422+
if inputs_kwargs:
1423+
for k, v in inputs_kwargs.items():
1424+
if v is not None:
1425+
data[k] = v
1426+
data[input_param.kwargs_type][k] = v
1427+
1428+
# Check intermediates
1429+
for input_param in self.intermediates_inputs:
1430+
if input_param.name:
1431+
value = state.get_intermediate(input_param.name)
1432+
if input_param.required and value is None:
1433+
raise ValueError(f"Required intermediate input '{input_param.name}' is missing")
1434+
elif value is not None or (value is None and input_param.name not in data):
1435+
data[input_param.name] = value
1436+
elif input_param.kwargs_type:
1437+
# if kwargs_type is provided, get all intermediates with matching kwargs_type
1438+
if input_param.kwargs_type not in data:
1439+
data[input_param.kwargs_type] = {}
1440+
intermediates_kwargs = state.get_intermediates_kwargs(input_param.kwargs_type)
1441+
if intermediates_kwargs:
1442+
for k, v in intermediates_kwargs.items():
1443+
if v is not None:
1444+
if k not in data:
1445+
data[k] = v
1446+
data[input_param.kwargs_type][k] = v
1447+
return BlockState(**data)
1448+
1449+
def add_block_state(self, state: PipelineState, block_state: BlockState):
1450+
for output_param in self.intermediates_outputs:
1451+
if not hasattr(block_state, output_param.name):
1452+
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
1453+
param = getattr(block_state, output_param.name)
1454+
state.add_intermediate(output_param.name, param, output_param.kwargs_type)
11731455

11741456
# YiYi TODO:
11751457
# 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess)

0 commit comments

Comments
 (0)