Skip to content

Commit 00cae4e

Browse files
committed
docstring doc doc doc
1 parent b3fb418 commit 00cae4e

File tree

2 files changed

+1452
-557
lines changed

2 files changed

+1452
-557
lines changed

src/diffusers/pipelines/modular_pipeline.py

Lines changed: 82 additions & 166 deletions
Original file line numberDiff line numberDiff line change
@@ -228,27 +228,26 @@ def format_intermediates_short(intermediates_inputs: List[InputParam], required_
228228
output_parts.extend(outputs)
229229

230230
# Combine with arrow notation if both inputs and outputs exist
231-
if input_parts and output_parts:
232-
return f"{', '.join(input_parts)} -> {', '.join(output_parts)}"
231+
if output_parts:
232+
return f"-> {', '.join(output_parts)}" if not input_parts else f"{', '.join(input_parts)} -> {', '.join(output_parts)}"
233233
elif input_parts:
234234
return ', '.join(input_parts)
235-
elif output_parts:
236-
return ', '.join(output_parts)
237235
return ""
238236

239237

240-
def format_input_params(input_params: List[InputParam], indent_level: int = 4, max_line_length: int = 115) -> str:
241-
"""Format a list of InputParam objects into a readable string representation.
238+
def format_params(params: List[Union[InputParam, OutputParam]], header: str = "Args", indent_level: int = 4, max_line_length: int = 115) -> str:
239+
"""Format a list of InputParam or OutputParam objects into a readable string representation.
242240
243241
Args:
244-
input_params: List of InputParam objects to format
242+
params: List of InputParam or OutputParam objects to format
243+
header: Header text to use (e.g. "Args" or "Returns")
245244
indent_level: Number of spaces to indent each parameter line (default: 4)
246245
max_line_length: Maximum length for each line before wrapping (default: 115)
247246
248247
Returns:
249-
A formatted string representing all input parameters
248+
A formatted string representing all parameters
250249
"""
251-
if not input_params:
250+
if not params:
252251
return ""
253252

254253
base_indent = " " * indent_level
@@ -270,10 +269,8 @@ def wrap_text(text: str, indent: str, max_length: int) -> str:
270269
current_length = 0
271270

272271
for word in words:
273-
# Calculate word length including space
274272
word_length = len(word) + (1 if current_line else 0)
275273

276-
# Check if adding this word would exceed the max length
277274
if current_line and current_length + word_length > max_length:
278275
lines.append(" ".join(current_line))
279276
current_line = [word]
@@ -285,22 +282,22 @@ def wrap_text(text: str, indent: str, max_length: int) -> str:
285282
if current_line:
286283
lines.append(" ".join(current_line))
287284

288-
# Join lines with proper indentation
289285
return f"\n{indent}".join(lines)
290286

291-
# Add the "Args:" header
292-
formatted_params.append(f"{base_indent}Args:")
287+
# Add the header
288+
formatted_params.append(f"{base_indent}{header}:")
293289

294-
for param in input_params:
290+
for param in params:
295291
# Format parameter name and type
296292
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
297293
param_str = f"{param_indent}{param.name} (`{type_str}`"
298294

299-
# Add optional tag and default value if parameter is optional
300-
if not param.required:
301-
param_str += ", *optional*"
302-
if param.default is not None:
303-
param_str += f", defaults to {param.default}"
295+
# Add optional tag and default value if parameter is an InputParam and optional
296+
if isinstance(param, InputParam):
297+
if not param.required:
298+
param_str += ", *optional*"
299+
if param.default is not None:
300+
param_str += f", defaults to {param.default}"
304301
param_str += "):"
305302

306303
# Add description on a new line with additional indentation and wrapping
@@ -317,84 +314,61 @@ def wrap_text(text: str, indent: str, max_length: int) -> str:
317314

318315
return "\n\n".join(formatted_params)
319316

317+
# Then update the original functions to use this combined version:
318+
def format_input_params(input_params: List[InputParam], indent_level: int = 4, max_line_length: int = 115) -> str:
319+
return format_params(input_params, "Args", indent_level, max_line_length)
320320

321321
def format_output_params(output_params: List[OutputParam], indent_level: int = 4, max_line_length: int = 115) -> str:
322-
"""Format a list of OutputParam objects into a readable string representation.
322+
return format_params(output_params, "Returns", indent_level, max_line_length)
323323

324-
Args:
325-
output_params: List of OutputParam objects to format
326-
indent_level: Number of spaces to indent each parameter line (default: 4)
327-
max_line_length: Maximum length for each line before wrapping (default: 115)
328324

329-
Returns:
330-
A formatted string representing all output parameters
325+
326+
def make_doc_string(inputs, intermediates_inputs, intermediates_outputs, final_intermediates_outputs=None, description=""):
331327
"""
332-
if not output_params:
333-
return ""
334-
335-
base_indent = " " * indent_level
336-
param_indent = " " * (indent_level + 4)
337-
desc_indent = " " * (indent_level + 8)
338-
formatted_params = []
328+
Generates a formatted documentation string describing the pipeline block's parameters and structure.
339329
340-
def get_type_str(type_hint):
341-
if hasattr(type_hint, "__origin__") and type_hint.__origin__ is Union:
342-
types = [t.__name__ if hasattr(t, "__name__") else str(t) for t in type_hint.__args__]
343-
return f"Union[{', '.join(types)}]"
344-
return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)
345-
346-
def wrap_text(text: str, indent: str, max_length: int) -> str:
347-
"""Wrap text while preserving markdown links and maintaining indentation."""
348-
words = text.split()
349-
lines = []
350-
current_line = []
351-
current_length = 0
330+
Returns:
331+
str: A formatted string containing information about call parameters, intermediate inputs/outputs,
332+
and final intermediate outputs.
333+
"""
334+
output = ""
352335

353-
for word in words:
354-
word_length = len(word) + (1 if current_line else 0)
355-
356-
if current_line and current_length + word_length > max_length:
357-
lines.append(" ".join(current_line))
358-
current_line = [word]
359-
current_length = len(word)
360-
else:
361-
current_line.append(word)
362-
current_length += word_length
363-
364-
if current_line:
365-
lines.append(" ".join(current_line))
366-
367-
return f"\n{indent}".join(lines)
368-
369-
# Add the "Returns:" header
370-
formatted_params.append(f"{base_indent}Returns:")
336+
if description:
337+
desc_lines = description.strip().split('\n')
338+
aligned_desc = '\n'.join(' ' + line for line in desc_lines)
339+
output += aligned_desc + "\n\n"
340+
341+
output += format_input_params(inputs + intermediates_inputs, indent_level=2)
371342

372-
for param in output_params:
373-
# Format parameter name and type
374-
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
375-
param_str = f"{param_indent}{param.name} (`{type_str}`):"
376-
377-
# Add description on a new line with additional indentation and wrapping
378-
if param.description:
379-
desc = re.sub(
380-
r'\[(.*?)\]\((https?://[^\s\)]+)\)',
381-
r'[\1](\2)',
382-
param.description
383-
)
384-
wrapped_desc = wrap_text(desc, desc_indent, max_line_length)
385-
param_str += f"\n{desc_indent}{wrapped_desc}"
386-
387-
formatted_params.append(param_str)
343+
# YiYi TODO: refactor to remove this and `outputs` attribute instead
344+
if final_intermediates_outputs:
345+
output += "\n\n"
346+
output += format_output_params(final_intermediates_outputs, indent_level=2)
347+
348+
if intermediates_outputs:
349+
output += "\n\n------------------------\n"
350+
intermediates_str = format_params(intermediates_outputs, "Intermediates Outputs", indent_level=2)
351+
output += intermediates_str
388352

389-
return "\n\n".join(formatted_params)
353+
elif intermediates_outputs:
354+
output +="\n\n"
355+
output += format_output_params(intermediates_outputs, indent_level=2)
356+
357+
358+
return output
359+
390360

391361
class PipelineBlock:
392362
# YiYi Notes: do we need this?
393363
# pipelie block should set the default value for all expected config/components, so maybe we do not need to explicitly set the list
394364
expected_components = []
395365
expected_configs = []
396366
model_name = None
397-
367+
368+
@property
369+
def description(self) -> str:
370+
return ""
371+
398372
@property
399373
def inputs(self) -> List[InputParam]:
400374
return []
@@ -472,7 +446,7 @@ def __repr__(self):
472446

473447
# Intermediates section
474448
intermediates_str = format_intermediates_short(self.intermediates_inputs, self.required_intermediates_inputs, self.intermediates_outputs)
475-
intermediates = f"Intermediates:\n {intermediates_str}"
449+
intermediates = f"Intermediates(`*` = modified):\n {intermediates_str}"
476450

477451
return (
478452
f"{class_name}(\n"
@@ -484,33 +458,11 @@ def __repr__(self):
484458
f")"
485459
)
486460

487-
def get_doc_string(self):
488-
"""
489-
Generates a formatted documentation string describing the pipeline block's parameters and structure.
490-
491-
Returns:
492-
str: A formatted string containing information about call parameters, intermediate inputs/outputs,
493-
and final intermediate outputs.
494-
"""
495-
output = "Call Parameters:\n"
496-
output += "------------------------\n"
497-
output += format_input_params(self.inputs, indent_level=2)
498-
499-
output += "\n\nIntermediate inputs:\n"
500-
output += "--------------------------\n"
501-
output += format_input_params(self.intermediates_inputs, indent_level=2)
502461

503-
if hasattr(self, "intermediates_outputs"):
504-
output += "\n\nIntermediate outputs:\n"
505-
output += "--------------------------\n"
506-
output += format_output_params(self.intermediates_outputs, indent_level=2)
507-
508-
if hasattr(self, "final_intermediates_outputs"):
509-
output += "\nFinal intermediate outputs:\n"
510-
output += "--------------------------\n"
511-
output += format_output_params(self.final_intermediates_outputs, indent_level=2)
462+
@property
463+
def doc(self):
464+
return make_doc_string(self.inputs, self.intermediates_inputs, self.intermediates_outputs, None, self.description)
512465

513-
return output
514466

515467
def get_block_state(self, state: PipelineState) -> dict:
516468
"""Get all inputs and intermediates in one dictionary"""
@@ -643,6 +595,10 @@ def __init__(self):
643595
@property
644596
def model_name(self):
645597
return next(iter(self.blocks.values())).model_name
598+
599+
@property
600+
def description(self):
601+
return ""
646602

647603
@property
648604
def expected_components(self):
@@ -849,7 +805,7 @@ def __repr__(self):
849805
inputs_str = format_inputs_short(block.inputs)
850806
sections.append(f" inputs:\n {inputs_str}")
851807

852-
intermediates_str = f" intermediates:\n {format_intermediates_short(block.intermediates_inputs, block.required_intermediates_inputs, block.intermediates_outputs)}"
808+
intermediates_str = f" intermediates(`*` = modified):\n {format_intermediates_short(block.intermediates_inputs, block.required_intermediates_inputs, block.intermediates_outputs)}"
853809
sections.append(intermediates_str)
854810

855811
sections.append("")
@@ -861,33 +817,9 @@ def __repr__(self):
861817
f")"
862818
)
863819

864-
def get_doc_string(self):
865-
"""
866-
Generates a formatted documentation string describing the pipeline block's parameters and structure.
867-
868-
Returns:
869-
str: A formatted string containing information about call parameters, intermediate inputs/outputs,
870-
and final intermediate outputs.
871-
"""
872-
output = "Call Parameters:\n"
873-
output += "------------------------\n"
874-
output += format_input_params(self.inputs, indent_level=2)
875-
876-
output += "\n\nIntermediate inputs:\n"
877-
output += "--------------------------\n"
878-
output += format_input_params(self.intermediates_inputs, indent_level=2)
879-
880-
if hasattr(self, "intermediates_outputs"):
881-
output += "\n\nIntermediate outputs:\n"
882-
output += "--------------------------\n"
883-
output += format_output_params(self.intermediates_outputs, indent_level=2)
884-
885-
if hasattr(self, "final_intermediates_outputs"):
886-
output += "\nFinal intermediate outputs:\n"
887-
output += "--------------------------\n"
888-
output += format_output_params(self.final_intermediates_outputs, indent_level=2)
889-
890-
return output
820+
@property
821+
def doc(self):
822+
return make_doc_string(self.inputs, self.intermediates_inputs, self.intermediates_outputs, None, self.description)
891823

892824
class SequentialPipelineBlocks:
893825
"""
@@ -899,6 +831,10 @@ class SequentialPipelineBlocks:
899831
@property
900832
def model_name(self):
901833
return next(iter(self.blocks.values())).model_name
834+
835+
@property
836+
def description(self):
837+
return ""
902838

903839
@property
904840
def expected_components(self):
@@ -1192,7 +1128,7 @@ def __repr__(self):
11921128
intermediates_str = format_intermediates_short(block.intermediates_inputs, block.required_intermediates_inputs, block.intermediates_outputs)
11931129

11941130
if intermediates_str:
1195-
blocks_str += f" intermediates: {intermediates_str}\n"
1131+
blocks_str += f" intermediates(`*` = modified): {intermediates_str}\n"
11961132
blocks_str += "\n"
11971133

11981134
inputs_str = format_inputs_short(self.inputs)
@@ -1220,33 +1156,9 @@ def __repr__(self):
12201156
f")"
12211157
)
12221158

1223-
def get_doc_string(self):
1224-
"""
1225-
Generates a formatted documentation string describing the pipeline block's parameters and structure.
1226-
1227-
Returns:
1228-
str: A formatted string containing information about call parameters, intermediate inputs/outputs,
1229-
and final intermediate outputs.
1230-
"""
1231-
output = "Call Parameters:\n"
1232-
output += "------------------------\n"
1233-
output += format_input_params(self.inputs, indent_level=2)
1234-
1235-
output += "\n\nIntermediate inputs:\n"
1236-
output += "--------------------------\n"
1237-
output += format_input_params(self.intermediates_inputs, indent_level=2)
1238-
1239-
if hasattr(self, "intermediates_outputs"):
1240-
output += "\n\nIntermediate outputs:\n"
1241-
output += "--------------------------\n"
1242-
output += format_output_params(self.intermediates_outputs, indent_level=2)
1243-
1244-
if hasattr(self, "final_intermediates_outputs"):
1245-
output += "\nFinal intermediate outputs:\n"
1246-
output += "--------------------------\n"
1247-
output += format_output_params(self.final_intermediates_outputs, indent_level=2)
1248-
1249-
return output
1159+
@property
1160+
def doc(self):
1161+
return make_doc_string(self.inputs, self.intermediates_inputs, self.intermediates_outputs, self.final_intermediates_outputs, self.description)
12501162

12511163
class ModularPipeline(ConfigMixin):
12521164
"""
@@ -1467,8 +1379,9 @@ def default_call_parameters(self) -> Dict[str, Any]:
14671379
def __repr__(self):
14681380
output = "ModularPipeline:\n"
14691381
output += "==============================\n\n"
1470-
1382+
14711383
block = self.pipeline_block
1384+
14721385
if hasattr(block, "trigger_inputs") and block.trigger_inputs:
14731386
output += "\n"
14741387
output += " Trigger Inputs:\n"
@@ -1514,7 +1427,10 @@ def __repr__(self):
15141427
output += "\n"
15151428

15161429
# List the call parameters
1517-
output += self.pipeline_block.get_doc_string()
1430+
full_doc = self.pipeline_block.doc
1431+
if "------------------------" in full_doc:
1432+
full_doc = full_doc.split("------------------------")[0].rstrip()
1433+
output += full_doc
15181434

15191435
return output
15201436

0 commit comments

Comments
 (0)