
Commit 047c643

tidy(app): document & clean up batch prep logic
1 parent d1e03aa commit 047c643

2 files changed (+159, -31 lines)

invokeai/app/services/session_queue/session_queue_common.py

Lines changed: 153 additions & 25 deletions

@@ -406,58 +406,143 @@ class IsFullResult(BaseModel):
 # region Util
 
 
-def create_graph_nfv_tuples(batch: Batch, maximum: int) -> Generator[tuple[str, str, list[dict]], None, None]:
+def create_session_nfv_tuples(batch: Batch, maximum: int) -> Generator[tuple[str, str, str], None, None]:
     """
-    Create all graph permutations from the given batch data and graph. Yields tuples
-    of the form (graph, batch_data_items) where batch_data_items is the list of BatchDataItems
-    that was applied to the graph.
+    Given a batch and a maximum number of sessions to create, generate a tuple of session_id, session_json, and
+    field_values_json for each session.
+
+    The batch has a "source" graph and a data property. The data property is a list of lists of BatchDatum objects.
+    Each BatchDatum has a field identifier (e.g. a node id and field name), and a list of values to substitute into
+    the field.
+
+    This structure allows us to create a new graph for every possible permutation of BatchDatum objects:
+    - Each BatchDatum can be "expanded" into a dict of node-field-value tuples - one for each item in the BatchDatum.
+    - Zip each inner list of expanded BatchDatum objects together. Call this a "batch_data_list".
+    - Take the cartesian product of all zipped batch_data_lists, resulting in a list of lists of BatchDatum objects.
+      Each inner list now represents the substitution values for a single permutation (session).
+    - For each permutation, substitute the values into the graph.
+
+    This function is optimized for performance, as it is used to generate a large number of sessions at once.
+
+    Args:
+        batch: The batch to generate sessions from
+        maximum: The maximum number of sessions to generate
+
+    Returns:
+        A generator that yields tuples of session_id, session_json, and field_values_json for each session. The
+        generator will stop early if the maximum number of sessions is reached.
     """
 
     # TODO: Should this be a class method on Batch?
 
     data: list[list[tuple[dict]]] = []
     batch_data_collection = batch.data if batch.data is not None else []
-    graph_as_dict = batch.graph.model_dump(warnings=False, exclude_none=True)
-    session_dict = GraphExecutionState(graph=Graph()).model_dump(warnings=False, exclude_none=True)
 
     for batch_datum_list in batch_data_collection:
-        # each batch_datum_list needs to be convered to NodeFieldValues and then zipped
-
         node_field_values_to_zip: list[list[dict]] = []
+        # Expand each BatchDatum into a list of dicts - one for each item in the BatchDatum
         for batch_datum in batch_datum_list:
             node_field_values = [
+                # Note: A tuple here is slightly faster than a dict, but we need the object in dict form to be inserted
+                # in the session_queue table anyways. So, overall creating NFVs as dicts is faster.
                 {"node_path": batch_datum.node_path, "field_name": batch_datum.field_name, "value": item}
                 for item in batch_datum.items
            ]
             node_field_values_to_zip.append(node_field_values)
+        # Zip the dicts together to create a list of dicts for each permutation
         data.append(list(zip(*node_field_values_to_zip, strict=True)))  # type: ignore [arg-type]
 
-    # create generator to yield session,nfv tuples
+    # We serialize the graph and session once, then mutate the graph dict in place for each session.
+    #
+    # This sounds scary, but it's actually fine.
+    #
+    # The batch prep logic injects field values into the same fields for each generated session.
+    #
+    # For example, after the product operation, we'll end up with a list of node-field-value tuples like this:
+    # [
+    #     (
+    #         {"node_path": "1", "field_name": "a", "value": 1},
+    #         {"node_path": "2", "field_name": "b", "value": 2},
+    #         {"node_path": "3", "field_name": "c", "value": 3},
+    #     ),
+    #     (
+    #         {"node_path": "1", "field_name": "a", "value": 4},
+    #         {"node_path": "2", "field_name": "b", "value": 5},
+    #         {"node_path": "3", "field_name": "c", "value": 6},
+    #     )
+    # ]
+    #
+    # Note that each tuple has the same length, and each tuple substitutes values in for exactly the same node fields.
+    # No matter the complexity of the batch, this property holds true.
+    #
+    # This means each permutation's substitution can be done in-place on the same graph dict, because it overwrites the
+    # previous mutation. We only need to serialize the graph once, and then we can mutate it in place for each session.
+    #
+    # Previously, we had created new Graph objects for each session, but this was very slow for large (1k+ session
+    # batches). We then tried dumping the graph to dict and using deep-copy to create a new dict for each session,
+    # but this was also slow.
+    #
+    # Overall, we achieved a 100x speedup by mutating the graph dict in place for each session over creating new Graph
+    # objects for each session.
+    #
+    # We will also mutate the session dict in place, setting a new ID for each session and setting the mutated graph
+    # dict as the session's graph.
+
+    # Dump the batch's graph to a dict once
+    graph_as_dict = batch.graph.model_dump(warnings=False, exclude_none=True)
+
+    # We must provide a Graph object when creating the "dummy" session dict, but we don't actually use it. It will be
+    # overwritten for each session by the mutated graph_as_dict.
+    session_dict = GraphExecutionState(graph=Graph()).model_dump(warnings=False, exclude_none=True)
+
+    # Now we can create a generator that yields the session_id, session_json, and field_values_json for each session.
     count = 0
+
+    # Each batch may have multiple runs, so we need to generate the same number of sessions for each run. The total is
+    # still limited by the maximum number of sessions.
     for _ in range(batch.runs):
         for d in product(*data):
             if count >= maximum:
+                # We've reached the maximum number of sessions we may generate
                 return
+
+            # Flatten the list of lists of dicts into a single list of dicts
+            # TODO(psyche): Is there a more efficient way to do this?
             flat_node_field_values = list(chain.from_iterable(d))
 
-            # The fields that are injected for each the same for all graphs. Therefore, we can mutate the graph dict
-            # in place and then serialize it to json for each session. It's functionally the same as creating a new
-            # graph dict for each session, but is more efficient.
+            # Need a fresh ID for each session
             session_id = uuid_string()
+
+            # Mutate the session dict in place
             session_dict["id"] = session_id
 
-            for item in flat_node_field_values:
-                graph_as_dict["nodes"][item["node_path"]][item["field_name"]] = item["value"]
+            # Substitute the values into the graph
+            for nfv in flat_node_field_values:
+                graph_as_dict["nodes"][nfv["node_path"]][nfv["field_name"]] = nfv["value"]
 
+            # Mutate the session dict in place
             session_dict["graph"] = graph_as_dict
-            yield (session_id, json.dumps(session_dict, default=to_jsonable_python), flat_node_field_values)
+
+            # Serialize the session and field values
+            # Note the use of pydantic's to_jsonable_python to handle serialization of any python object, including sets.
+            session_json = json.dumps(session_dict, default=to_jsonable_python)
+            field_values_json = json.dumps(flat_node_field_values, default=to_jsonable_python)
+
+            # Yield the session_id, session_json, and field_values_json
+            yield (session_id, session_json, field_values_json)
+
+            # Increment the count so we know when to stop
             count += 1
 
 
 def calc_session_count(batch: Batch) -> int:
     """
-    Calculates the number of sessions that would be created by the batch, without incurring
-    the overhead of actually generating them. Adapted from `create_sessions().
+    Calculates the number of sessions that would be created by the batch, without incurring the overhead of actually
+    creating them, as is done in `create_session_nfv_tuples()`.
+
+    The count is used to communicate to the user how many sessions were _requested_ to be created, as opposed to how
+    many were _actually_ created (which may be less due to the maximum number of sessions).
     """
     # TODO: Should this be a class method on Batch?
     if not batch.data:
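The expand → zip → product scheme and the mutate-in-place trick documented in the hunk above can be demonstrated end-to-end on toy data. Below is a minimal standalone sketch, not code from this commit: plain dicts stand in for InvokeAI's BatchDatum model, and toy_batch_data plus the hard-coded graph dict are illustrative only.

import json
from itertools import chain, product

# Two "batch_datum_lists". Within a list, every BatchDatum must expand to the same
# number of items so the expansions can be zipped together.
toy_batch_data = [
    [
        {"node_path": "1", "field_name": "a", "items": [1, 4]},
        {"node_path": "2", "field_name": "b", "items": [2, 5]},
    ],
    [
        {"node_path": "3", "field_name": "c", "items": [10, 20, 30]},
    ],
]

data = []
for batch_datum_list in toy_batch_data:
    # Expand each BatchDatum into one node-field-value dict per item...
    expanded = [
        [{"node_path": bd["node_path"], "field_name": bd["field_name"], "value": v} for v in bd["items"]]
        for bd in batch_datum_list
    ]
    # ...then zip the expansions so each tuple carries one value for every field in the list.
    data.append(list(zip(*expanded, strict=True)))

# The cartesian product yields one combination per session. Every combination writes
# exactly the same node/field slots, so mutating one shared graph dict in place is safe:
# each iteration fully overwrites the values written by the previous one.
graph_as_dict = {"nodes": {"1": {"a": None}, "2": {"b": None}, "3": {"c": None}}}
for d in product(*data):
    flat_node_field_values = list(chain.from_iterable(d))
    for nfv in flat_node_field_values:
        graph_as_dict["nodes"][nfv["node_path"]][nfv["field_name"]] = nfv["value"]
    print(json.dumps(graph_as_dict))  # 2 zipped pairs x 3 items = 6 sessions

Serializing inside the loop is what makes the shared-dict mutation safe: each dumped JSON string is an immutable snapshot, so later iterations cannot corrupt sessions that were already yielded.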
@@ -473,20 +558,63 @@ def calc_session_count(batch: Batch) -> int:
     return len(data_product) * batch.runs
 
 
-def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new_queue_items: int) -> list[tuple]:
-    values_to_insert: list[tuple] = []
+ValueToInsertTuple: TypeAlias = tuple[
+    str,  # queue_id
+    str,  # session (as stringified JSON)
+    str,  # session_id
+    str,  # batch_id
+    str | None,  # field_values (optional, as stringified JSON)
+    int,  # priority
+    str | None,  # workflow (optional, as stringified JSON)
+    str | None,  # origin (optional)
+    str | None,  # destination (optional)
+    str | None,  # retried_from_item_id (optional, this is always None for new items)
+]
+"""A type alias for the tuple of values to insert into the session queue table."""
+
+
+def prepare_values_to_insert(
+    queue_id: str, batch: Batch, priority: int, max_new_queue_items: int
+) -> list[ValueToInsertTuple]:
+    """
+    Given a batch, prepare the values to insert into the session queue table. The list of tuples can be used with an
+    `executemany` statement to insert multiple rows at once.
+
+    Args:
+        queue_id: The ID of the queue to insert the items into
+        batch: The batch to prepare the values for
+        priority: The priority of the queue items
+        max_new_queue_items: The maximum number of queue items to insert
+
+    Returns:
+        A list of tuples to insert into the session queue table. Each tuple contains the following values:
+        - queue_id
+        - session (as stringified JSON)
+        - session_id
+        - batch_id
+        - field_values (optional, as stringified JSON)
+        - priority
+        - workflow (optional, as stringified JSON)
+        - origin (optional)
+        - destination (optional)
+        - retried_from_item_id (optional, this is always None for new items)
+    """
+
+    # A tuple is a fast and memory-efficient way to store the values to insert. Previously, we used a NamedTuple, but
+    # measured a ~5% performance improvement by using a normal tuple instead. For very large batches (10k+ items),
+    # this difference becomes noticeable.
+    #
+    # So, despite the inferior DX with normal tuples, we use one here for performance reasons.
+
+    values_to_insert: list[ValueToInsertTuple] = []
+
     # pydantic's to_jsonable_python handles serialization of any python object, including sets, which json.dumps does
     # not support by default. Apparently there are sets somewhere in the graph.
 
     # The same workflow is used for all sessions in the batch - serialize it once
     workflow_json = json.dumps(batch.workflow, default=to_jsonable_python) if batch.workflow else None
 
-    for session_id, session_json, field_values in create_graph_nfv_tuples(batch, max_new_queue_items):
-        # As a perf optimization, we can mutate the session_dict in place. This is safe because we dump it to json
-        # as part of the tuple construction
-
-        field_values_json = json.dumps(field_values, default=to_jsonable_python) if field_values else None
-
+    for session_id, session_json, field_values_json in create_session_nfv_tuples(batch, max_new_queue_items):
         values_to_insert.append(
             (
                 queue_id,
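The new docstring says the returned tuples are intended for an `executemany` statement. A hedged sketch of what that consumption could look like follows; the CREATE TABLE definition and the hand-rolled row below are illustrative assumptions only — the real session_queue schema lives in InvokeAI's SQLite migrations and is not part of this diff.

import sqlite3

conn = sqlite3.connect(":memory:")
# Hypothetical table with one column per ValueToInsertTuple field, in the same order.
conn.execute(
    """
    CREATE TABLE session_queue (
        queue_id TEXT,
        session TEXT,
        session_id TEXT,
        batch_id TEXT,
        field_values TEXT,
        priority INTEGER,
        workflow TEXT,
        origin TEXT,
        destination TEXT,
        retried_from_item_id TEXT
    )
    """
)

# values_to_insert would normally come from prepare_values_to_insert(...); a single
# hand-rolled ValueToInsertTuple-shaped row stands in for it here.
values_to_insert = [
    ("default", '{"id": "abc"}', "abc", "batch-1", None, 0, None, None, None, None),
]

# One parameterized statement inserts every prepared row in a single call.
conn.executemany(
    "INSERT INTO session_queue VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
    values_to_insert,
)
conn.commit()

Batching all rows into one `executemany` call amortizes statement preparation across the whole batch, which is why the function returns a flat list of tuples rather than inserting row by row.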

tests/test_session_queue.py

Lines changed: 6 additions & 6 deletions

@@ -9,7 +9,7 @@
     BatchDatum,
     NodeFieldValue,
     calc_session_count,
-    create_graph_nfv_tuples,
+    create_session_nfv_tuples,
     prepare_values_to_insert,
 )
 from invokeai.app.services.shared.graph import Graph, GraphExecutionState
@@ -42,7 +42,7 @@ def batch_graph() -> Graph:
 
 def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph):
     b = Batch(graph=batch_graph, data=batch_data_collection, runs=2)
-    t = list(create_graph_nfv_tuples(batch=b, maximum=1000))
+    t = list(create_session_nfv_tuples(batch=b, maximum=1000))
     # 2 list[BatchDatum] * length 2 * 2 runs = 8
     assert len(t) == 8
 
@@ -90,28 +90,28 @@ def test_create_sessions_from_batch_with_runs(batch_data_collection, batch_graph
 
 def test_create_sessions_from_batch_without_runs(batch_data_collection, batch_graph):
     b = Batch(graph=batch_graph, data=batch_data_collection)
-    t = list(create_graph_nfv_tuples(batch=b, maximum=1000))
+    t = list(create_session_nfv_tuples(batch=b, maximum=1000))
     # 2 list[BatchDatum] * length 2 * 1 runs = 4
     assert len(t) == 4
 
 
 def test_create_sessions_from_batch_without_batch(batch_graph):
     b = Batch(graph=batch_graph, runs=2)
-    t = list(create_graph_nfv_tuples(batch=b, maximum=1000))
+    t = list(create_session_nfv_tuples(batch=b, maximum=1000))
     # 2 runs
     assert len(t) == 2
 
 
 def test_create_sessions_from_batch_without_batch_or_runs(batch_graph):
     b = Batch(graph=batch_graph)
-    t = list(create_graph_nfv_tuples(batch=b, maximum=1000))
+    t = list(create_session_nfv_tuples(batch=b, maximum=1000))
     # 1 run
     assert len(t) == 1
 
 
 def test_create_sessions_from_batch_with_runs_and_max(batch_data_collection, batch_graph):
     b = Batch(graph=batch_graph, data=batch_data_collection, runs=2)
-    t = list(create_graph_nfv_tuples(batch=b, maximum=5))
+    t = list(create_session_nfv_tuples(batch=b, maximum=5))
     # 2 list[BatchDatum] * length 2 * 2 runs = 8, but max is 5
     assert len(t) == 5
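The expected counts in these tests fall out of the zip/product arithmetic described above: each zipped batch_data_list contributes its length as a factor, the result is multiplied by the number of runs, and create_session_nfv_tuples additionally stops at `maximum`. A small sketch of that arithmetic follows; the expected_session_count helper is illustrative only, not the real implementation.

from math import prod

def expected_session_count(zipped_lengths: list[int], runs: int, maximum: int | None = None) -> int:
    # With no batch data at all, each run still produces exactly one session.
    count = (prod(zipped_lengths) if zipped_lengths else 1) * runs
    # The generator (but not calc_session_count) also caps output at `maximum`.
    return count if maximum is None else min(count, maximum)

assert expected_session_count([2, 2], runs=2) == 8             # ...with_runs
assert expected_session_count([2, 2], runs=1) == 4             # ...without_runs
assert expected_session_count([], runs=2) == 2                 # ...without_batch
assert expected_session_count([], runs=1) == 1                 # ...without_batch_or_runs
assert expected_session_count([2, 2], runs=2, maximum=5) == 5  # ...with_runs_and_max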