@@ -406,58 +406,143 @@ class IsFullResult(BaseModel):
 # region Util
 
 
-def create_graph_nfv_tuples(batch: Batch, maximum: int) -> Generator[tuple[str, str, list[dict]], None, None]:
+def create_session_nfv_tuples(batch: Batch, maximum: int) -> Generator[tuple[str, str, str], None, None]:
410410 """
411- Create all graph permutations from the given batch data and graph. Yields tuples
412- of the form (graph, batch_data_items) where batch_data_items is the list of BatchDataItems
413- that was applied to the graph.
411+ Given a batch and a maximum number of sessions to create, generate a tuple of session_id, session_json, and
412+ field_values_json for each session.
413+
414+ The batch has a "source" graph and a data property. The data property is a list of lists of BatchDatum objects.
415+ Each BatchDatum has a field identifier (e.g. a node id and field name), and a list of values to substitute into
416+ the field.
417+
418+ This structure allows us to create a new graph for every possible permutation of BatchDatum objects:
419+ - Each BatchDatum can be "expanded" into a dict of node-field-value tuples - one for each item in the BatchDatum.
420+ - Zip each inner list of expanded BatchDatum objects together. Call this a "batch_data_list".
421+ - Take the cartesian product of all zipped batch_data_lists, resulting in a list of permutations of BatchDatum
422+ - Take the cartesian product of all zipped batch_data_lists, resulting in a list of lists of BatchDatum objects.
423+ Each inner list now represents the substitution values for a single permutation (session).
424+ - For each permutation, substitute the values into the graph
425+
426+ This function is optimized for performance, as it is used to generate a large number of sessions at once.
427+
428+ Args:
429+ batch: The batch to generate sessions from
430+ maximum: The maximum number of sessions to generate
431+
432+ Returns:
433+ A generator that yields tuples of session_id, session_json, and field_values_json for each session. The
434+ generator will stop early if the maximum number of sessions is reached.
414435 """
 
     # TODO: Should this be a class method on Batch?
 
     data: list[list[tuple[dict]]] = []
     batch_data_collection = batch.data if batch.data is not None else []
-    graph_as_dict = batch.graph.model_dump(warnings=False, exclude_none=True)
-    session_dict = GraphExecutionState(graph=Graph()).model_dump(warnings=False, exclude_none=True)
 
     for batch_datum_list in batch_data_collection:
-        # each batch_datum_list needs to be convered to NodeFieldValues and then zipped
-
         node_field_values_to_zip: list[list[dict]] = []
+        # Expand each BatchDatum into a list of dicts - one for each item in the BatchDatum
         for batch_datum in batch_datum_list:
             node_field_values = [
+                # Note: A tuple here is slightly faster than a dict, but we need the object in dict form to be
+                # inserted in the session_queue table anyways. So, overall creating NFVs as dicts is faster.
                 {"node_path": batch_datum.node_path, "field_name": batch_datum.field_name, "value": item}
                 for item in batch_datum.items
             ]
             node_field_values_to_zip.append(node_field_values)
+        # Zip the dicts together to create a list of dicts for each permutation
         data.append(list(zip(*node_field_values_to_zip, strict=True)))  # type: ignore [arg-type]
 
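A toy sketch of the expand-zip-product scheme described in the docstring, using stand-in data rather than the real `Batch` model: two fields zipped in lockstep form one group, a third field forms another, and `itertools.product` yields one flat list of node-field-value dicts per session.

```python
from itertools import product

# Group 1: fields "a" and "b" vary in lockstep (zipped together).
group_1 = list(
    zip(
        [{"node_path": "1", "field_name": "a", "value": v} for v in (1, 4)],
        [{"node_path": "2", "field_name": "b", "value": v} for v in (2, 5)],
        strict=True,
    )
)
# Group 2: field "c" varies independently of group 1.
group_2 = [({"node_path": "3", "field_name": "c", "value": v},) for v in (7, 8)]

# 2 zipped pairs x 2 values of "c" -> 4 permutations (sessions).
for permutation in product(group_1, group_2):
    flat_node_field_values = [nfv for group in permutation for nfv in group]
    print(flat_node_field_values)
```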
-    # create generator to yield session,nfv tuples
+    # We serialize the graph and session once, then mutate the graph dict in place for each session.
+    #
+    # This sounds scary, but it's actually fine.
+    #
+    # The batch prep logic injects field values into the same fields for each generated session.
+    #
+    # For example, after the product operation, we'll end up with a list of node-field-value tuples like this:
+    # [
+    #     (
+    #         {"node_path": "1", "field_name": "a", "value": 1},
+    #         {"node_path": "2", "field_name": "b", "value": 2},
+    #         {"node_path": "3", "field_name": "c", "value": 3},
+    #     ),
+    #     (
+    #         {"node_path": "1", "field_name": "a", "value": 4},
+    #         {"node_path": "2", "field_name": "b", "value": 5},
+    #         {"node_path": "3", "field_name": "c", "value": 6},
+    #     )
+    # ]
+    #
+    # Note that each tuple has the same length, and each tuple substitutes values in for exactly the same node
+    # fields. No matter the complexity of the batch, this property holds true.
+    #
+    # This means each permutation's substitution can be done in place on the same graph dict, because it overwrites
+    # the previous mutation. We only need to serialize the graph once, and then we can mutate it in place for each
+    # session.
+    #
+    # Previously, we had created new Graph objects for each session, but this was very slow for large (1k+ session)
+    # batches. We then tried dumping the graph to a dict and using deep-copy to create a new dict for each session,
+    # but this was also slow.
+    #
+    # Overall, we achieved a ~100x speedup by mutating the graph dict in place instead of creating new Graph objects
+    # for each session.
+    #
+    # We will also mutate the session dict in place, setting a new ID for each session and setting the mutated graph
+    # dict as the session's graph.
+
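A minimal illustration of why the in-place mutation is safe (hypothetical one-node graph, two permutations): every iteration writes the same node-field slots, and the dict is serialized before the next iteration overwrites them.

```python
import json

# Hypothetical graph dict with a single node and field.
graph_as_dict = {"nodes": {"1": {"a": None}}}

for value in (1, 4):
    # Each permutation writes the same slot, fully overwriting the previous one...
    graph_as_dict["nodes"]["1"]["a"] = value
    # ...and we serialize before the next mutation, so no state leaks between sessions.
    print(json.dumps(graph_as_dict))
```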
+    # Dump the batch's graph to a dict once
+    graph_as_dict = batch.graph.model_dump(warnings=False, exclude_none=True)
+
+    # We must provide a Graph object when creating the "dummy" session dict, but we don't actually use it. It will
+    # be overwritten for each session by the mutated graph_as_dict.
+    session_dict = GraphExecutionState(graph=Graph()).model_dump(warnings=False, exclude_none=True)
+
+    # Now we can create a generator that yields the session_id, session_json, and field_values_json for each session.
     count = 0
+
+    # Each batch may have multiple runs, so we need to generate the same number of sessions for each run. The total
+    # is still limited by the maximum number of sessions.
     for _ in range(batch.runs):
         for d in product(*data):
             if count >= maximum:
+                # We've reached the maximum number of sessions we may generate
                 return
+
+            # Flatten the list of lists of dicts into a single list of dicts
+            # TODO(psyche): Is there a more efficient way to do this?
             flat_node_field_values = list(chain.from_iterable(d))
 
-            # The fields that are injected for each session are the same for all graphs. Therefore, we can mutate
-            # the graph dict in place and then serialize it to json for each session. It's functionally the same as
-            # creating a new graph dict for each session, but is more efficient.
+            # Need a fresh ID for each session
             session_id = uuid_string()
+
+            # Mutate the session dict in place
             session_dict["id"] = session_id
 
-            for item in flat_node_field_values:
-                graph_as_dict["nodes"][item["node_path"]][item["field_name"]] = item["value"]
+            # Substitute the values into the graph
+            for nfv in flat_node_field_values:
+                graph_as_dict["nodes"][nfv["node_path"]][nfv["field_name"]] = nfv["value"]
 
+            # Mutate the session dict in place
             session_dict["graph"] = graph_as_dict
-            yield (session_id, json.dumps(session_dict, default=to_jsonable_python), flat_node_field_values)
+
+            # Serialize the session and field values
+            # Note the use of pydantic's to_jsonable_python to handle serialization of any python object,
+            # including sets.
+            session_json = json.dumps(session_dict, default=to_jsonable_python)
+            field_values_json = json.dumps(flat_node_field_values, default=to_jsonable_python)
+
+            # Yield the session_id, session_json, and field_values_json
+            yield (session_id, session_json, field_values_json)
+
+            # Increment the count so we know when to stop
             count += 1
 
 
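A quick standalone demonstration of the `to_jsonable_python` fallback mentioned in the comments above: `json.dumps` rejects sets, but passing pydantic's converter as the `default` hook handles them.

```python
import json

from pydantic_core import to_jsonable_python

# json.dumps raises TypeError on sets; the `default` hook is called for any
# object json can't serialize natively, and to_jsonable_python converts it.
print(json.dumps({"tags": {"a", "b"}}, default=to_jsonable_python))
# -> {"tags": ["a", "b"]} (set ordering is not guaranteed)
```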
 def calc_session_count(batch: Batch) -> int:
     """
-    Calculates the number of sessions that would be created by the batch, without incurring
-    the overhead of actually generating them. Adapted from `create_sessions().
+    Calculates the number of sessions that would be created by the batch, without incurring the overhead of actually
+    creating them, as is done in `create_session_nfv_tuples()`.
+
+    The count is used to communicate to the user how many sessions were _requested_ to be created, as opposed to how
+    many were _actually_ created (which may be less due to the maximum number of sessions).
     """
     # TODO: Should this be a class method on Batch?
     if not batch.data:
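An illustrative check of the requested-vs-actual distinction, with hypothetical numbers: the requested count is the product of each zipped group's item count times the number of runs, while actual creation is capped by the maximum.

```python
from math import prod

group_lengths = [2, 3]  # item counts of each zipped batch_data_list
runs = 2
maximum = 10

requested = prod(group_lengths) * runs  # the count reported to the user: 12
actual = min(requested, maximum)        # what generation actually yields: 10
print(requested, actual)
```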
@@ -473,20 +558,63 @@ def calc_session_count(batch: Batch) -> int:
     return len(data_product) * batch.runs
 
 
-def prepare_values_to_insert(queue_id: str, batch: Batch, priority: int, max_new_queue_items: int) -> list[tuple]:
-    values_to_insert: list[tuple] = []
+ValueToInsertTuple: TypeAlias = tuple[
+    str,  # queue_id
+    str,  # session (as stringified JSON)
+    str,  # session_id
+    str,  # batch_id
+    str | None,  # field_values (optional, as stringified JSON)
+    int,  # priority
+    str | None,  # workflow (optional, as stringified JSON)
+    str | None,  # origin (optional)
+    str | None,  # destination (optional)
+    str | None,  # retried_from_item_id (optional, this is always None for new items)
+]
+"""A type alias for the tuple of values to insert into the session queue table."""
+
+
+def prepare_values_to_insert(
+    queue_id: str, batch: Batch, priority: int, max_new_queue_items: int
+) -> list[ValueToInsertTuple]:
+    """
+    Given a batch, prepare the values to insert into the session queue table. The list of tuples can be used with an
+    `executemany` statement to insert multiple rows at once (see the sketch after this docstring).
+
+    Args:
+        queue_id: The ID of the queue to insert the items into
+        batch: The batch to prepare the values for
+        priority: The priority of the queue items
+        max_new_queue_items: The maximum number of queue items to insert
+
+    Returns:
+        A list of tuples to insert into the session queue table. Each tuple contains the following values:
+        - queue_id
+        - session (as stringified JSON)
+        - session_id
+        - batch_id
+        - field_values (optional, as stringified JSON)
+        - priority
+        - workflow (optional, as stringified JSON)
+        - origin (optional)
+        - destination (optional)
+        - retried_from_item_id (optional, this is always None for new items)
+    """
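A reduced sketch of how such a list of tuples feeds `executemany`. The real `session_queue` schema has more columns than shown; the three stand-in columns here are for illustration only.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Stand-in for the real session_queue table - column set assumed for illustration.
conn.execute("CREATE TABLE session_queue (queue_id TEXT, session TEXT, session_id TEXT)")

rows = [
    ("default", '{"id": "sess-1"}', "sess-1"),
    ("default", '{"id": "sess-2"}', "sess-2"),
]
# One statement inserts every prepared row.
conn.executemany("INSERT INTO session_queue (queue_id, session, session_id) VALUES (?, ?, ?)", rows)
conn.commit()
```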
+
+    # A tuple is a fast and memory-efficient way to store the values to insert. Previously, we used a NamedTuple,
+    # but measured a ~5% performance improvement by using a normal tuple instead. For very large batches (10k+
+    # items), this difference becomes noticeable.
+    #
+    # So, despite the inferior DX with normal tuples, we use one here for performance reasons.
+
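A sketch of the kind of micro-benchmark behind the tuple-vs-NamedTuple note above. This is not the project's actual measurement, and numbers vary by machine; it only shows how one might compare the two construction costs.

```python
import timeit
from typing import NamedTuple

class Row(NamedTuple):
    queue_id: str
    session_id: str

q, s = "default", "sess-1"
plain = timeit.timeit(lambda: (q, s), number=1_000_000)    # build a plain tuple
named = timeit.timeit(lambda: Row(q, s), number=1_000_000)  # build a NamedTuple
print(f"tuple: {plain:.3f}s  NamedTuple: {named:.3f}s")
```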
+    values_to_insert: list[ValueToInsertTuple] = []
+
     # pydantic's to_jsonable_python handles serialization of any python object, including sets, which json.dumps does
     # not support by default. Apparently there are sets somewhere in the graph.
 
     # The same workflow is used for all sessions in the batch - serialize it once
     workflow_json = json.dumps(batch.workflow, default=to_jsonable_python) if batch.workflow else None
 
-    for session_id, session_json, field_values in create_graph_nfv_tuples(batch, max_new_queue_items):
-        # As a perf optimization, we can mutate the session_dict in place. This is safe because we dump it to json
-        # as part of the tuple construction
-
-        field_values_json = json.dumps(field_values, default=to_jsonable_python) if field_values else None
-
+    for session_id, session_json, field_values_json in create_session_nfv_tuples(batch, max_new_queue_items):
         values_to_insert.append(
             (
                 queue_id,