diff --git a/README.md b/README.md index b14da492..0810fda4 100644 --- a/README.md +++ b/README.md @@ -257,17 +257,17 @@ disjunctive_graph = build_disjunctive_graph(instance) >>> disjunctive_graph.nodes_by_type defaultdict(list, - {: [Node(node_type=OPERATION, value=O(m=0, d=1, j=0, p=0), id=0), - Node(node_type=OPERATION, value=O(m=1, d=1, j=0, p=1), id=1), - Node(node_type=OPERATION, value=O(m=2, d=7, j=0, p=2), id=2), - Node(node_type=OPERATION, value=O(m=1, d=5, j=1, p=0), id=3), - Node(node_type=OPERATION, value=O(m=2, d=1, j=1, p=1), id=4), - Node(node_type=OPERATION, value=O(m=0, d=1, j=1, p=2), id=5), - Node(node_type=OPERATION, value=O(m=2, d=1, j=2, p=0), id=6), - Node(node_type=OPERATION, value=O(m=0, d=3, j=2, p=1), id=7), - Node(node_type=OPERATION, value=O(m=1, d=2, j=2, p=2), id=8)], - : [Node(node_type=SOURCE, value=None, id=9)], - : [Node(node_type=SINK, value=None, id=10)]}) + {: [Node(node_type=OPERATION, value=O(m=0, d=1, j=0, p=0), id=("OPERATION", 0)), + Node(node_type=OPERATION, value=O(m=1, d=1, j=0, p=1), id=("OPERATION", 1)), + Node(node_type=OPERATION, value=O(m=2, d=7, j=0, p=2), id=("OPERATION", 2)), + Node(node_type=OPERATION, value=O(m=1, d=5, j=1, p=0), id=("OPERATION", 3)), + Node(node_type=OPERATION, value=O(m=2, d=1, j=1, p=1), id=("OPERATION", 4)), + Node(node_type=OPERATION, value=O(m=0, d=1, j=1, p=2), id=("OPERATION", 5)), + Node(node_type=OPERATION, value=O(m=2, d=1, j=2, p=0), id=("OPERATION", 6)), + Node(node_type=OPERATION, value=O(m=0, d=3, j=2, p=1), id=("OPERATION", 7)), + Node(node_type=OPERATION, value=O(m=1, d=2, j=2, p=2), id=("OPERATION", 8))], + : [Node(node_type=SOURCE, value=None, id=('SOURCE', 0))], + : [Node(node_type=SINK, value=None, id=('SINK', 0))]}) ``` Other attributes include: diff --git a/docs/source/examples/05-Load-Benchmark-Instances.ipynb b/docs/source/examples/05-Load-Benchmark-Instances.ipynb index 5fb90990..e97e3140 100644 --- a/docs/source/examples/05-Load-Benchmark-Instances.ipynb +++ 
b/docs/source/examples/05-Load-Benchmark-Instances.ipynb @@ -303,6 +303,7 @@ ], "source": [ "import numpy as np\n", + "\n", "np.array(ft06.durations_matrix)" ] }, diff --git a/docs/source/examples/09-SingleJobShopGraphEnv.ipynb b/docs/source/examples/09-SingleJobShopGraphEnv.ipynb index 0bb8c255..7e2961ae 100644 --- a/docs/source/examples/09-SingleJobShopGraphEnv.ipynb +++ b/docs/source/examples/09-SingleJobShopGraphEnv.ipynb @@ -49,9 +49,7 @@ " feature_observer_configs=feature_observer_configs,\n", " reward_function_config=DispatcherObserverConfig(IdleTimeReward),\n", " render_mode=\"human\", # Try \"save_video\"\n", - " render_config={\n", - " \"video_config\": {\"fps\": 4}\n", - " }\n", + " render_config={\"video_config\": {\"fps\": 4}},\n", ")" ] }, @@ -252,7 +250,9 @@ "import numpy as np\n", "\n", "rewards = np.array(env.reward_function.rewards)\n", - "print(f\"{len(list(filter(lambda x: x != 0, rewards)))} zeros out of {len(rewards)}\")" + "print(\n", + " f\"{len(list(filter(lambda x: x != 0, rewards)))} zeros out of {len(rewards)}\"\n", + ")" ] }, { diff --git a/docs/source/examples/12-ReadyOperationsFilter.ipynb b/docs/source/examples/12-ReadyOperationsFilter.ipynb index f36aced9..811c2c27 100644 --- a/docs/source/examples/12-ReadyOperationsFilter.ipynb +++ b/docs/source/examples/12-ReadyOperationsFilter.ipynb @@ -375,7 +375,13 @@ " df = print_makespans_for_optimizations(ta01, pdr)\n", " print()\n", " # Chamge the number of xticks to 10\n", - " df.sort_values(by=\"Makespan\", ascending=False).plot.barh(x=\"Filter Combination\", y=\"Makespan\", title=pdr.name, xlim=(1500, 2000), xticks=range(1500, 2001, 50))\n", + " df.sort_values(by=\"Makespan\", ascending=False).plot.barh(\n", + " x=\"Filter Combination\",\n", + " y=\"Makespan\",\n", + " title=pdr.name,\n", + " xlim=(1500, 2000),\n", + " xticks=range(1500, 2001, 50),\n", + " )\n", " dfs.append(df)" ] }, @@ -408,7 +414,9 @@ "source": [ "df_combined = pd.concat(dfs, keys=[pdr.name for pdr in 
DispatchingRuleType])\n", "\n", - "df_combined.groupby(\"Filter Combination\").agg(\"mean\").sort_values(by=\"Makespan\", ascending=False).plot.barh(title=\"Average Makespan\")" + "df_combined.groupby(\"Filter Combination\").agg(\"mean\").sort_values(\n", + " by=\"Makespan\", ascending=False\n", + ").plot.barh(title=\"Average Makespan\")" ] }, { diff --git a/job_shop_lib/.DS_Store b/job_shop_lib/.DS_Store new file mode 100644 index 00000000..3fc0afc8 Binary files /dev/null and b/job_shop_lib/.DS_Store differ diff --git a/job_shop_lib/__init__.py b/job_shop_lib/__init__.py index 2b8ab5b8..5ed1c324 100644 --- a/job_shop_lib/__init__.py +++ b/job_shop_lib/__init__.py @@ -19,7 +19,7 @@ from job_shop_lib._base_solver import BaseSolver, Solver -__version__ = "1.6.1" +__version__ = "1.7.0" __all__ = [ "Operation", diff --git a/job_shop_lib/dispatching/feature_observers/_dates_observer.py b/job_shop_lib/dispatching/feature_observers/_dates_observer.py index d1839b27..2c8540ce 100644 --- a/job_shop_lib/dispatching/feature_observers/_dates_observer.py +++ b/job_shop_lib/dispatching/feature_observers/_dates_observer.py @@ -135,8 +135,7 @@ def update(self, scheduled_operation: ScheduledOperation): elapsed_time = current_time - self._previous_current_time self._previous_current_time = current_time cols = [ - self._attribute_map[attr] - for attr in self.attributes_to_observe + self._attribute_map[attr] for attr in self.attributes_to_observe ] self.features[FeatureType.OPERATIONS][:, cols] -= elapsed_time diff --git a/job_shop_lib/graphs/_build_disjunctive_graph.py b/job_shop_lib/graphs/_build_disjunctive_graph.py index 32353246..c2c663f6 100644 --- a/job_shop_lib/graphs/_build_disjunctive_graph.py +++ b/job_shop_lib/graphs/_build_disjunctive_graph.py @@ -81,8 +81,12 @@ def build_solved_disjunctive_graph(schedule: Schedule) -> JobShopGraph: break next_scheduled_operation = machine_schedule[i + 1] graph.add_edge( - scheduled_operation.operation.operation_id, - 
next_scheduled_operation.operation.operation_id, + graph.get_operation_node( + scheduled_operation.operation.operation_id + ), + graph.get_operation_node( + next_scheduled_operation.operation.operation_id + ), type=EdgeType.DISJUNCTIVE, ) diff --git a/job_shop_lib/graphs/_job_shop_graph.py b/job_shop_lib/graphs/_job_shop_graph.py index 2ca18a29..ea71eafa 100644 --- a/job_shop_lib/graphs/_job_shop_graph.py +++ b/job_shop_lib/graphs/_job_shop_graph.py @@ -1,7 +1,9 @@ """Home of the `JobShopGraph` class.""" - +from typing import DefaultDict import collections + import networkx as nx +import numpy as np from job_shop_lib import JobShopInstance from job_shop_lib.exceptions import ValidationError @@ -15,16 +17,20 @@ class JobShopGraph: """Represents a :class:`JobShopInstance` as a heterogeneous directed graph. - Provides a comprehensive graph-based representation of a job shop - scheduling problem, utilizing the ``networkx`` library to model the complex - relationships between jobs, operations, and machines. This class transforms - the abstract scheduling problem into a directed graph, where various - entities (jobs, machines, and operations) are nodes, and the dependencies - (such as operation order within a job or machine assignment) are edges. + Internally, the graph is represented using adjacency lists + (``adjaceny_in`` and ``adjacency_out``) for efficient + addition and removal of nodes and edges. + + The graph can be converted to a + :class:`networkx.DiGraph` on-demand via the + :meth:`~JobShopGraph.get_networkx_graph` method. This transformation allows for the application of graph algorithms to analyze and solve scheduling problems. + The class generates and manages node identifiers as tuples of the + form `(node_type_name, local_id)`, e.g., `("operation", 42)`. + Args: instance: The job shop instance that the graph represents. 
@@ -36,33 +42,63 @@ class JobShopGraph: __slots__ = { "instance": "The job shop instance that the graph represents.", - "graph": ( - "The directed graph representing the job shop, where nodes are " - "operations, machines, jobs, or abstract concepts like global, " - "source, and sink, with edges indicating dependencies." - ), "_nodes": "List of all nodes added to the graph.", + "_nodes_map": ( + "Dictionary mapping node ids to nodes for quick access." + ), "_nodes_by_type": "Dictionary mapping node types to lists of nodes.", "_nodes_by_machine": ( "List of lists mapping machine ids to operation nodes." ), "_nodes_by_job": "List of lists mapping job ids to operation nodes.", - "_next_node_id": ( - "The id to assign to the next node added to thegraph." - ), "removed_nodes": ( - "List of boolean values indicating whether a node has been " - "removed from the graph." + "Dictionary mapping instance ids to a boolean indicating whether " + "a node has been removed." + "The keys are node types, and the values are lists mapping " + "instance ids to booleans. This allows for quick access " + "to removed nodes by their instance ids." + ), + "instance_id_map": ( + "Dictionary mapping instance ids to nodes for quick access." + "The keys are node types, and the values are dictionaries mapping " + "instance ids to nodes. This allows for quick access to " + "nodes by their operation, machine, or job ids." + ), + "adjacency_in": ( + "Stores graph adjacency information of incoming edges," + "mapping nodes to their neighbors based on edge types. The " + "keys are either edge types or tuples of (source_node_type, " + "'to', destination_node_type), and the values are lists of " + "nodes that are connected to the key node type or tuple." 
+ "In case of conjunctive or disjunctive edges, these edge types" + "will replace the 'to' component of the type tuple" + ), + "adjacency_out": ( + "Stores graph adjacency information of outgoing edges," + "mapping nodes to their neighbors based on edge types. The " + "keys are either edge types or tuples of (source_node_type, " + "'to', destination_node_type), and the values are lists of " + "nodes that are connected to the key node type or tuple." + "In case of conjunctive or disjunctive edges, these edge types" + "will replace the 'to' component of the type tuple" + ), + "edge_types": ( + "A set of all edge types present in the graph." + "Only includes tuples of " + "(source_node_type, 'to', destination_node_type)," + "processing conjunctive and disjunctive edges, " + "replacing the 'to' component of the type tuple" + "with the appropriate edge type" ), } def __init__( self, instance: JobShopInstance, add_operation_nodes: bool = True ): - self.graph = nx.DiGraph() self.instance = instance self._nodes: list[Node] = [] + self._nodes_map: dict[tuple[str, int], Node] = {} self._nodes_by_type: dict[NodeType, list[Node]] = ( collections.defaultdict(list) ) @@ -72,47 +108,85 @@ def __init__( self._nodes_by_job: list[list[Node]] = [ [] for _ in range(instance.num_jobs) ] - self._next_node_id = 0 - self.removed_nodes: list[bool] = [] + # Changed: _next_node_id is now removed + # self._next_node_id = collections.defaultdict(int) + # Changed: removed_nodes is now a dictionary of lists + self.removed_nodes: dict[str, list[bool]] = collections.defaultdict( + list + ) + self.instance_id_map: dict[str, dict[int, Node]] = ( + collections.defaultdict(dict) + ) + self.adjacency_in: dict[ + Node, + dict[tuple[str, str, str], list[Node]], + ] = {} + self.adjacency_out: dict[ + Node, + dict[tuple[str, str, str], list[Node]], + ] = {} + if add_operation_nodes: self.add_operation_nodes() + self.edge_types = set[tuple[str, str, str]]() - @property - def nodes(self) -> list[Node]: - 
"""List of all nodes added to the graph. + def get_networkx_graph(self) -> nx.DiGraph: + """Constructs and returns a ``networkx.DiGraph`` object on-demand. - It may contain nodes that have been removed from the graph. + Each node is represented by their node ID, and all edges have a "type" + property in their data. """ + g = nx.DiGraph() + # Add only the nodes that have not been removed + for node_obj in self.non_removed_nodes(): + g.add_node(node_obj.node_id, **{NODE_ATTR: node_obj}) + + # Add edges as edges from removed nodes are not included + # in the graph, as the process of removing nodes also removes + # all edges connected to them. + for node, neighbors in self.adjacency_out.items(): + for edge_type, neighbor_nodes in neighbors.items(): + for neighbor in neighbor_nodes: + g.add_edge( + node.node_id, + neighbor.node_id, + type=edge_type, + ) + return g + + @property + def nodes(self) -> list[Node]: + """List of all nodes added to the graph.""" return self._nodes @property - def nodes_by_type(self) -> dict[NodeType, list[Node]]: - """Dictionary mapping node types to lists of nodes. + def nodes_map(self) -> dict[tuple[str, int], Node]: + """Dictionary mapping node ids to nodes for quick access.""" + return self._nodes_map - It may contain nodes that have been removed from the graph. - """ + @property + def nodes_by_type(self) -> dict[NodeType, list[Node]]: + """Dictionary mapping node types to lists of nodes.""" return self._nodes_by_type @property def nodes_by_machine(self) -> list[list[Node]]: - """List of lists mapping machine ids to operation nodes. - - It may contain nodes that have been removed from the graph. - """ + """List of lists mapping machine ids to operation nodes.""" return self._nodes_by_machine @property def nodes_by_job(self) -> list[list[Node]]: - """List of lists mapping job ids to operation nodes. - - It may contain nodes that have been removed from the graph. 
- """ + """List of lists mapping job ids to operation nodes.""" return self._nodes_by_job @property def num_edges(self) -> int: """Number of edges in the graph.""" - return self.graph.number_of_edges() + return sum( + len(neighbors) + for edges_by_type in self.adjacency_out.values() + for neighbors in edges_by_type.values() + ) @property def num_job_nodes(self) -> int: @@ -130,10 +204,10 @@ def add_node(self, node_for_adding: Node) -> None: """Adds a node to the graph and updates relevant class attributes. This method assigns a unique identifier to the node, adds it to the - graph, and updates the nodes list and the nodes_by_type dictionary. If - the node is of type :class:`NodeType.OPERATION`, it also updates - ``nodes_by_job`` and ``nodes_by_machine`` based on the operation's - job id and machine ids. + graph, and updates the nodes list and the nodes_by_type dictionary. The + id is a tuple `(node_type_name, local_id)`. If the node is of type + :class:`NodeType.OPERATION`, it also updates ``nodes_by_job`` and + ``nodes_by_machine`` based on the operation's job id and machine ids. Args: node_for_adding: @@ -145,25 +219,57 @@ def add_node(self, node_for_adding: Node) -> None: should be done exclusively through this method to avoid inconsistencies. 
""" - node_for_adding.node_id = self._next_node_id - self.graph.add_node( - node_for_adding.node_id, **{NODE_ATTR: node_for_adding} - ) + # Changed: Node ID generation logic + node_type_name = node_for_adding.node_type.name.lower() + local_id = len(self._nodes_by_type[node_for_adding.node_type]) + new_id = (node_type_name, local_id) + + node_for_adding.node_id = new_id self._nodes_by_type[node_for_adding.node_type].append(node_for_adding) self._nodes.append(node_for_adding) - self._next_node_id += 1 - self.removed_nodes.append(False) + self._nodes_map[new_id] = node_for_adding if node_for_adding.node_type == NodeType.OPERATION: operation = node_for_adding.operation self._nodes_by_job[operation.job_id].append(node_for_adding) for machine_id in operation.machines: self._nodes_by_machine[machine_id].append(node_for_adding) + self.instance_id_map[NodeType.OPERATION.name.lower()][ + operation.operation_id + ] = node_for_adding + if NodeType.OPERATION.name.lower() not in self.removed_nodes: + self.removed_nodes[NodeType.OPERATION.name.lower()] = [ + False + ] * self.instance.num_operations + elif node_for_adding.node_type == NodeType.MACHINE: + self.instance_id_map[NodeType.MACHINE.name.lower()][ + node_for_adding.machine_id + ] = node_for_adding + if NodeType.MACHINE.name.lower() not in self.removed_nodes: + self.removed_nodes[NodeType.MACHINE.name.lower()] = [ + False + ] * self.instance.num_machines + elif node_for_adding.node_type == NodeType.JOB: + self.instance_id_map[NodeType.JOB.name.lower()][ + node_for_adding.job_id + ] = node_for_adding + if NodeType.JOB.name.lower() not in self.removed_nodes: + self.removed_nodes[NodeType.JOB.name.lower()] = [ + False + ] * self.instance.num_jobs + else: + # For other node types, we can use a default id of 0 + self.instance_id_map[node_type_name][0] = node_for_adding + self.removed_nodes[node_type_name].append(False) + + # Initialize adjacency lists for the new node + self.adjacency_in[node_for_adding] = 
collections.defaultdict(list) + self.adjacency_out[node_for_adding] = collections.defaultdict(list) def add_edge( self, - u_of_edge: Node | int, - v_of_edge: Node | int, + u_of_edge: Node, + v_of_edge: Node, **attr, ) -> None: r"""Adds an edge to the graph. @@ -171,13 +277,13 @@ def add_edge( It automatically determines the edge type based on the source and destination nodes unless explicitly provided in the ``attr`` argument via the ``type`` key. The edge type is a tuple of strings: - ``(source_node_type, "to", destination_node_type)``. + ``(source_node_type, "to", destination_node_type)``. If edges of + type "conjunctive" or "disjunctive" are being added, the "to" + component of the edge type will be replaced accordingly. Args: u_of_edge: - The source node of the edge. If it is a :class:`Node`, its - ``node_id`` is used as the source. Otherwise, it is assumed to - be the ``node_id`` of the source. + The source node of the edge. Can be a :class:`Node` v_of_edge: The destination node of the edge. If it is a :class:`Node`, its ``node_id`` is used as the destination. Otherwise, it @@ -189,56 +295,221 @@ def add_edge( ValidationError: If ``u_of_edge`` or ``v_of_edge`` are not in the graph. """ - if isinstance(u_of_edge, Node): - u_of_edge = u_of_edge.node_id - if isinstance(v_of_edge, Node): - v_of_edge = v_of_edge.node_id - if u_of_edge not in self.graph or v_of_edge not in self.graph: + + # Ensure both nodes are in the graph + if u_of_edge not in self._nodes or v_of_edge not in self._nodes: raise ValidationError( "`u_of_edge` and `v_of_edge` must be in the graph." 
) edge_type = attr.pop("type", None) + self.edge_types.add((u_of_edge.node_id[0], "to", v_of_edge.node_id[0])) if edge_type is None: - u_node = self.nodes[u_of_edge] - v_node = self.nodes[v_of_edge] - edge_type = ( - u_node.node_type.name.lower(), - "to", - v_node.node_type.name.lower(), + edge_type = (u_of_edge.node_id[0], "to", v_of_edge.node_id[0]) + self.edge_types.add(edge_type) + else: + new_edge_type = ( + u_of_edge.node_id[0], + edge_type.name, + v_of_edge.node_id[0], ) - self.graph.add_edge(u_of_edge, v_of_edge, type=edge_type, **attr) + edge_type = new_edge_type + self.edge_types.add(new_edge_type) + self.adjacency_in[v_of_edge][edge_type].append(u_of_edge) + self.adjacency_out[u_of_edge][edge_type].append(v_of_edge) - def remove_node(self, node_id: int) -> None: - """Removes a node from the graph and the isolated nodes that result - from the removal. + def remove_node(self, node_id: tuple[str, int]) -> None: + """ + Removes a node and its edges from the graph, then renumbers the + local IDs of subsequent nodes of the same type to maintain a + contiguous sequence. + + This is a complex operation that involves modifying node IDs in-place + and ensuring all graph data structures remain consistent. + """ + node_type_name, local_id = node_id + + # 1. Verify the node exists before proceeding. + if node_id not in self._nodes_map: + return + + node_to_remove = self._nodes_map[node_id] + + if node_to_remove.node_type == NodeType.OPERATION: + operation = node_to_remove.operation + self.removed_nodes[NodeType.OPERATION.name.lower()][ + operation.operation_id + ] = True + elif node_to_remove.node_type == NodeType.MACHINE: + self.removed_nodes[NodeType.MACHINE.name.lower()][ + node_to_remove.machine_id + ] = True + elif node_to_remove.node_type == NodeType.JOB: + self.removed_nodes[NodeType.JOB.name.lower()][ + node_to_remove.job_id + ] = True + else: + # For other node types, we can use a default id of 0 + self.removed_nodes[node_type_name][0] = True + + # 2. 
Remove all edges connected to the node from the adjacency lists. + # Update neighbors that have incoming edges from the node. + if node_to_remove in self.adjacency_out: + for edge_type, neighbors in self.adjacency_out[ + node_to_remove + ].items(): + for neighbor in list(neighbors): + if ( + neighbor in self.adjacency_in + and edge_type in self.adjacency_in[neighbor] + ): + self.adjacency_in[neighbor][edge_type].remove( + node_to_remove + ) + + # Update neighbors that have outgoing edges to the node to be removed. + if node_to_remove in self.adjacency_in: + for edge_type, neighbors in self.adjacency_in[ + node_to_remove + ].items(): + for neighbor in list(neighbors): + if ( + neighbor in self.adjacency_out + and edge_type in self.adjacency_out[neighbor] + ): + self.adjacency_out[neighbor][edge_type].remove( + node_to_remove + ) + + # Remove the node itself from the adjacency lists. + self.adjacency_out.pop(node_to_remove, None) + self.adjacency_in.pop(node_to_remove, None) + + # 3. Physically remove the node from all tracking lists and maps. + self._nodes.remove(node_to_remove) + self._nodes_by_type[node_to_remove.node_type].remove(node_to_remove) + del self._nodes_map[node_id] + + if node_to_remove.node_type == NodeType.OPERATION: + operation = node_to_remove.operation + self._nodes_by_job[operation.job_id].remove(node_to_remove) + for machine_id in operation.machines: + self._nodes_by_machine[machine_id].remove(node_to_remove) + + # 4. Identify and renumber all subsequent nodes of the same type. + # Collect nodes that need re-numbering. + nodes_to_renumber = [ + node + for node in self._nodes_by_type[node_to_remove.node_type] + if node.node_id[1] > local_id + ] + # Sort them by their current ID to process in the correct order. + nodes_to_renumber.sort(key=lambda n: n.node_id[1]) + + for node_to_update in nodes_to_renumber: + old_node_id = node_to_update.node_id + new_node_id = (old_node_id[0], old_node_id[1] - 1) + + # Remove the node from the adjacency lists. 
+ in_edges = self.adjacency_in.pop(node_to_update, None) + out_edges = self.adjacency_out.pop(node_to_update, None) + + # Mutate the node's ID. + node_to_update.node_id = new_node_id + + # Re-insert adjacency data with the updated node object as the key. + if in_edges: + self.adjacency_in[node_to_update] = in_edges + if out_edges: + self.adjacency_out[node_to_update] = out_edges + + # Update the main node map to reflect the new ID. + del self._nodes_map[old_node_id] + self._nodes_map[new_node_id] = node_to_update + + def remove_edge( + self, + u_of_edge: Node, + v_of_edge: Node, + **attr, + ) -> None: + r"""Removes an edge from the graph. + + This method removes the edge between two nodes, updating the adjacency + lists accordingly. It also checks if the nodes are in the graph before + attempting to remove the edge. Args: - node_id: - The id of the node to remove. + u_of_edge: + The source node of the edge. + v_of_edge: + The destination node of the edge. + \**attr: + Additional attributes to identify the edge type. 
""" - self.graph.remove_node(node_id) - self.removed_nodes[node_id] = True + edge_type = attr.pop("type", None) + if edge_type is None: + edge_type = (u_of_edge.node_id[0], "to", v_of_edge.node_id[0]) + + # check if any of the nodes has been removed + if self.is_removed(u_of_edge) or self.is_removed(v_of_edge): + return + + # Remove from adjacency lists + if u_of_edge in self.adjacency_out: + self.adjacency_out[u_of_edge][edge_type].remove(v_of_edge) + if v_of_edge in self.adjacency_in: + self.adjacency_in[v_of_edge][edge_type].remove(u_of_edge) def remove_isolated_nodes(self) -> None: """Removes isolated nodes from the graph.""" - isolated_nodes = list(nx.isolates(self.graph)) - for isolated_node in isolated_nodes: - self.removed_nodes[isolated_node] = True - self.graph.remove_nodes_from(isolated_nodes) - - def is_removed(self, node: int | Node) -> bool: + # get isolated nodes, meaning nodes with no incoming or outgoing edges, + # for all edge types, meaning only empty lists + isolated_nodes = list() + for node in self._nodes: + cond1 = False + cond2 = False + if node in self.adjacency_in: + cond1 = all( + not neighbors + for neighbors in self.adjacency_in[node].values() + ) + if node in self.adjacency_out: + cond2 = all( + not neighbors + for neighbors in self.adjacency_out[node].values() + ) + if cond1 and cond2: + isolated_nodes.append(node) + + # Remove isolated nodes + for node in isolated_nodes: + self.remove_node(node.node_id) + + def is_removed(self, node: Node) -> bool: """Returns whether the node is removed from the graph. Args: node: - The node to check. If it is a ``Node``, its `node_id` is used - as the node to check. Otherwise, it is assumed to be the - ``node_id`` of the node to check. + The node to check. Can be a :class:`Node`. 
""" - if isinstance(node, Node): - node = node.node_id - return self.removed_nodes[node] + + if node.node_type.name == NodeType.OPERATION.name: + return self.removed_nodes[NodeType.OPERATION.name.lower()][ + node.operation.operation_id + ] + if node.node_type.name == NodeType.MACHINE.name: + return self.removed_nodes[NodeType.MACHINE.name.lower()][ + node.machine_id + ] + if node.node_type.name == NodeType.JOB.name: + return self.removed_nodes[NodeType.JOB.name.lower()][node.job_id] + # Default case for other node types + return ( + self.removed_nodes[node.node_type.name.lower()][0] + if node.node_type.name.lower() in self.removed_nodes + else False + ) def non_removed_nodes(self) -> list[Node]: """Returns the nodes that are not removed from the graph.""" @@ -299,21 +570,29 @@ def get_node_by_type_and_id( The node with the given id. """ - def get_nested_attr(obj, attr_path: str): - """Helper function to get nested attribute.""" - attrs = attr_path.split(".") - for attr in attrs: - obj = getattr(obj, attr) - return obj - - nodes = self._nodes_by_type[node_type] - if node_id < len(nodes): - node = nodes[node_id] - if get_nested_attr(node, id_attr) == node_id: - return node - - for node in nodes: - if get_nested_attr(node, id_attr) == node_id: - return node + nodes = self.instance_id_map[node_type.name.lower()] + if node_id in nodes: + return nodes[node_id] raise ValidationError(f"No node found with node.{id_attr}={node_id}") + + @property + def edge_index_dict(self) -> dict[tuple[str, str, str], np.ndarray]: + """Returns the edge index as a dictionary of numpy arrays. + The keys are edge types, and the values are numpy arrays of shape + (2, num_edges) representing the edges of that type. 
+ """ + edge_index: DefaultDict[tuple[str, str, str], np.ndarray] = ( + collections.defaultdict(lambda: np.empty((2, 0), np.int32)) + ) + for node, edges in self.adjacency_out.items(): + src = node.node_id[1] + for edge_type, neighbors in edges.items(): + if len(neighbors) == 0: + continue + dst = np.array( + [[src, neighbor.node_id[1]] for neighbor in neighbors], + dtype=np.int32, + ).T + edge_index[edge_type] = np.hstack((edge_index[edge_type], dst)) + return dict(edge_index) diff --git a/job_shop_lib/graphs/_node.py b/job_shop_lib/graphs/_node.py index 42a6676e..579120b9 100644 --- a/job_shop_lib/graphs/_node.py +++ b/job_shop_lib/graphs/_node.py @@ -13,7 +13,9 @@ class Node: A node is hashable by its id. The id is assigned when the node is added to the graph. The id must be unique for each node in the graph, and should be - used to identify the node in the networkx graph. + used to identify the node in the networkx graph. The id is a tuple + containing the node type's name as a string and a local integer id, + e.g., ``("MACHINE", 42)``. Depending on the type of the node, it can have different attributes. The following table shows the attributes of each type of node: @@ -29,7 +31,7 @@ class Node: +----------------+---------------------+ In terms of equality, two nodes are equal if they have the same id. - Additionally, one node is equal to an integer if the integer is equal to + Additionally, a node is equal to a tuple if the tuple is equal to its id. It is also hashable by its id. This allows for using the node as a key in a dictionary, at the same time @@ -38,10 +40,10 @@ class Node: .. 
code-block:: python node = Node(NodeType.SOURCE) - node.node_id = 1 + node.node_id = ("SOURCE", 0) graph = {node: "some value"} print(graph[node]) # "some value" - print(graph[1]) # "some value" + print(graph[("SOURCE", 0)]) # "some value" Args: node_type: @@ -67,7 +69,7 @@ class Node: __slots__ = { "node_type": "The type of the node.", - "_node_id": "Unique identifier for the node.", + "_node_id": "Unique identifier for the node (tuple[str, int]).", "_operation": ("The operation associated with the node."), "_machine_id": ("The machine ID associated with the node."), "_job_id": "The job ID associated with the node.", @@ -90,14 +92,14 @@ def __init__( raise ValidationError("Job node must have a job_id.") self.node_type: NodeType = node_type - self._node_id: int | None = None + self._node_id: tuple[str, int] | None = None self._operation = operation self._machine_id = machine_id self._job_id = job_id @property - def node_id(self) -> int: + def node_id(self) -> tuple[str, int]: """Returns a unique identifier for the node.""" if self._node_id is None: raise UninitializedAttributeError( @@ -106,7 +108,7 @@ def node_id(self) -> int: return self._node_id @node_id.setter - def node_id(self, value: int) -> None: + def node_id(self, value: tuple[str, int]) -> None: self._node_id = value @property @@ -150,27 +152,33 @@ def job_id(self) -> int: return self._job_id def __hash__(self) -> int: - return self.node_id + return hash(self.node_id) def __eq__(self, __value: object) -> bool: if isinstance(__value, Node): - __value = __value.node_id + return self.node_id == __value.node_id return self.node_id == __value def __repr__(self) -> str: + # Use self.node_id to trigger UninitializedAttributeError if not set + try: + node_id_repr = f"id={self.node_id}" + except UninitializedAttributeError: + node_id_repr = "id=None" + if self.node_type == NodeType.OPERATION: return ( - f"Node(node_type={self.node_type.name}, id={self._node_id}, " + f"Node(node_type={self.node_type.name}, 
{node_id_repr}, " f"operation={self.operation})" ) if self.node_type == NodeType.MACHINE: return ( - f"Node(node_type={self.node_type.name}, id={self._node_id}, " + f"Node(node_type={self.node_type.name}, {node_id_repr}, " f"machine_id={self._machine_id})" ) if self.node_type == NodeType.JOB: return ( - f"Node(node_type={self.node_type.name}, id={self._node_id}, " + f"Node(node_type={self.node_type.name}, {node_id_repr}, " f"job_id={self._job_id})" ) - return f"Node(node_type={self.node_type.name}, id={self._node_id})" + return f"Node(node_type={self.node_type.name}, {node_id_repr})" diff --git a/job_shop_lib/graphs/graph_updaters/_disjunctive_graph_updater.py b/job_shop_lib/graphs/graph_updaters/_disjunctive_graph_updater.py index a68b045a..8abc3563 100644 --- a/job_shop_lib/graphs/graph_updaters/_disjunctive_graph_updater.py +++ b/job_shop_lib/graphs/graph_updaters/_disjunctive_graph_updater.py @@ -64,9 +64,9 @@ def update(self, scheduled_operation: ScheduledOperation) -> None: # Remove the disjunctive arcs between the scheduled operation and the # previous operation - scheduled_operation_node = self.job_shop_graph.nodes[ + scheduled_operation_node = self.job_shop_graph.get_operation_node( scheduled_operation.operation.operation_id - ] + ) if ( scheduled_operation_node.operation is not scheduled_operation.operation @@ -77,14 +77,22 @@ def update(self, scheduled_operation: ScheduledOperation) -> None: "added to the graph. This method assumes that the operation id" " and node id are the same." 
) - scheduled_id = scheduled_operation_node.node_id - assert scheduled_id == scheduled_operation.operation.operation_id + scheduled_operation_node_op_id = ( + scheduled_operation_node.operation.operation_id + ) + assert ( + scheduled_operation_node_op_id + == scheduled_operation.operation.operation_id + ) previous_id = previous_scheduled_operation.operation.operation_id + previous_node = self.job_shop_graph.get_operation_node(previous_id) if self.job_shop_graph.is_removed( - previous_id - ) or self.job_shop_graph.is_removed(scheduled_id): + previous_node + ) or self.job_shop_graph.is_removed(scheduled_operation_node): return - self.job_shop_graph.graph.remove_edge(scheduled_id, previous_id) + self.job_shop_graph.remove_edge( + scheduled_operation_node, previous_node + ) # Now, remove all the disjunctive edges between the previous scheduled # operation and the other operations in the machine schedule @@ -100,9 +108,8 @@ def update(self, scheduled_operation: ScheduledOperation) -> None: for operation in operations_with_same_machine: if operation.operation_id in already_scheduled_operations: continue - self.job_shop_graph.graph.remove_edge( - previous_id, operation.operation_id - ) - self.job_shop_graph.graph.remove_edge( - operation.operation_id, previous_id + operation_node = self.job_shop_graph.get_operation_node( + operation.operation_id ) + self.job_shop_graph.remove_edge(previous_node, operation_node) + self.job_shop_graph.remove_edge(operation_node, previous_node) diff --git a/job_shop_lib/graphs/graph_updaters/_utils.py b/job_shop_lib/graphs/graph_updaters/_utils.py index 13332869..3916de4c 100644 --- a/job_shop_lib/graphs/graph_updaters/_utils.py +++ b/job_shop_lib/graphs/graph_updaters/_utils.py @@ -3,7 +3,7 @@ from collections.abc import Iterable from job_shop_lib import Operation -from job_shop_lib.graphs import JobShopGraph +from job_shop_lib.graphs import JobShopGraph, NodeType def remove_completed_operations( @@ -19,7 +19,11 @@ def 
remove_completed_operations( The dispatcher instance. """ for operation in completed_operations: - node_id = operation.operation_id - if job_shop_graph.removed_nodes[node_id]: + if job_shop_graph.removed_nodes[NodeType.OPERATION.name.lower()][ + operation.operation_id + ]: continue + node_id = job_shop_graph.get_operation_node( + operation.operation_id + ).node_id job_shop_graph.remove_node(node_id) diff --git a/job_shop_lib/reinforcement_learning/.DS_Store b/job_shop_lib/reinforcement_learning/.DS_Store new file mode 100644 index 00000000..7b212b7c Binary files /dev/null and b/job_shop_lib/reinforcement_learning/.DS_Store differ diff --git a/job_shop_lib/reinforcement_learning/__init__.py b/job_shop_lib/reinforcement_learning/__init__.py index c6841d75..5319b9e4 100644 --- a/job_shop_lib/reinforcement_learning/__init__.py +++ b/job_shop_lib/reinforcement_learning/__init__.py @@ -53,19 +53,12 @@ from job_shop_lib.reinforcement_learning._multi_job_shop_graph_env import ( MultiJobShopGraphEnv, ) -from ._resource_task_graph_observation import ( - ResourceTaskGraphObservation, - ResourceTaskGraphObservationDict, -) - __all__ = [ "SingleJobShopGraphEnv", "MultiJobShopGraphEnv", "ObservationDict", "ObservationSpaceKey", - "ResourceTaskGraphObservation", - "ResourceTaskGraphObservationDict", "RewardObserver", "MakespanReward", "IdleTimeReward", diff --git a/job_shop_lib/reinforcement_learning/_multi_job_shop_graph_env.py b/job_shop_lib/reinforcement_learning/_multi_job_shop_graph_env.py index 19dc8c8f..513530b6 100644 --- a/job_shop_lib/reinforcement_learning/_multi_job_shop_graph_env.py +++ b/job_shop_lib/reinforcement_learning/_multi_job_shop_graph_env.py @@ -1,9 +1,9 @@ """Home of the `GraphEnvironment` class.""" -from collections import defaultdict from collections.abc import Callable, Sequence from typing import Any from copy import deepcopy +from numpy.typing import NDArray import gymnasium as gym import numpy as np @@ -27,8 +27,6 @@ RenderConfig, MakespanReward, 
ObservationDict, - ObservationSpaceKey, - add_padding, ) @@ -145,13 +143,6 @@ class MultiJobShopGraphEnv(gym.Env): render_config: Configuration for rendering. See :class:`~job_shop_lib.RenderConfig`. - - use_padding: - Whether to use padding in observations. If True, all matrices - are padded to fixed sizes based on the maximum instance size. - Values are padded with -1, except for the "removed_nodes" key, - which is padded with ``True``, indicating that the node is - removed. """ def __init__( @@ -172,7 +163,6 @@ def __init__( ] = DispatcherObserverConfig(class_type=MakespanReward), render_mode: str | None = None, render_config: RenderConfig | None = None, - use_padding: bool = True, ) -> None: super().__init__() @@ -191,7 +181,6 @@ def __init__( ready_operations_filter=ready_operations_filter, render_mode=render_mode, render_config=render_config, - use_padding=use_padding, ) self.instance_generator = instance_generator self.graph_initializer = graph_initializer @@ -244,16 +233,6 @@ def ready_operations_filter( ready_operations_filter ) - @property - def use_padding(self) -> bool: - """Returns whether the padding is used.""" - return self.single_job_shop_graph_env.use_padding - - @use_padding.setter - def use_padding(self, use_padding: bool) -> None: - """Sets whether the padding is used.""" - self.single_job_shop_graph_env.use_padding = use_padding - @property def job_shop_graph(self) -> JobShopGraph: """Returns the current job shop graph.""" @@ -293,13 +272,13 @@ def reset( ready_operations_filter=self.ready_operations_filter, render_mode=self.render_mode, render_config=self.render_config, - use_padding=self.single_job_shop_graph_env.use_padding, ) obs, info = self.single_job_shop_graph_env.reset( seed=seed, options=options ) - if self.use_padding: - obs = self._add_padding_to_observation(obs) + self.observation_space = deepcopy( + self.single_job_shop_graph_env.observation_space + ) return obs, info @@ -331,55 +310,33 @@ def step( obs, reward, done, truncated, 
info = ( self.single_job_shop_graph_env.step(action) ) - if self.use_padding: - obs = self._add_padding_to_observation(obs) return obs, reward, done, truncated, info - def _add_padding_to_observation( - self, observation: ObservationDict - ) -> ObservationDict: - """Adds padding to the observation. - - "removed_nodes": - input_shape: (num_nodes,) - output_shape: (max_num_nodes,) (padded with True) - "edge_index": - input_shape: (2, num_edges) - output_shape: (2, max_num_edges) (padded with -1) - "operations": - input_shape: (num_operations, num_features) - output_shape: (max_num_operations, num_features) (padded with -1) - "jobs": - input_shape: (num_jobs, num_features) - output_shape: (max_num_jobs, num_features) (padded with -1) - "machines": - input_shape: (num_machines, num_features) - output_shape: (max_num_machines, num_features) (padded with -1) - """ - padding_value: dict[str, float | bool] = defaultdict(lambda: -1) - padding_value[ObservationSpaceKey.REMOVED_NODES.value] = True - for key, value in observation.items(): - if not isinstance(value, np.ndarray): # Make mypy happy - continue - expected_shape = self._get_output_shape(key) - observation[key] = add_padding( # type: ignore[literal-required] - value, - expected_shape, - padding_value=padding_value[key], - ) - return observation - - def _get_output_shape(self, key: str) -> tuple[int, ...]: + def _get_output_shape( + self, key: str + ) -> tuple[int, ...] 
| dict[str, tuple[int, ...]]: """Returns the output shape of the observation space key.""" - output_shape = self.observation_space[key].shape - assert output_shape is not None # Make mypy happy - return output_shape + space = self.observation_space[key] + + if isinstance(space, gym.spaces.Box): + assert space.shape is not None # mypy + return space.shape + elif isinstance(space, gym.spaces.Dict): + shapes: dict[str, tuple[int, ...]] = {} + for k, v in space.spaces.items(): + assert v.shape is not None # mypy knows shape is tuple now + shapes[k] = v.shape + return shapes + else: + raise ValueError( + f"Unsupported space type for key {key}: {type(space)}" + ) def render(self) -> None: self.single_job_shop_graph_env.render() - def get_available_actions_with_ids(self) -> list[tuple[int, int, int]]: + def get_available_actions_with_ids(self) -> NDArray[np.int32]: """Returns a list of available actions in the form of (operation_id, machine_id, job_id).""" return self.single_job_shop_graph_env.get_available_actions_with_ids() diff --git a/job_shop_lib/reinforcement_learning/_resource_task_graph_observation.py b/job_shop_lib/reinforcement_learning/_resource_task_graph_observation.py deleted file mode 100644 index f8d28989..00000000 --- a/job_shop_lib/reinforcement_learning/_resource_task_graph_observation.py +++ /dev/null @@ -1,331 +0,0 @@ -"""Contains wrappers for the environments.""" - -from typing import TypeVar, TypedDict, Generic, Any -from gymnasium import ObservationWrapper -import numpy as np -from numpy.typing import NDArray - -from job_shop_lib.reinforcement_learning import ( - ObservationDict, - SingleJobShopGraphEnv, - MultiJobShopGraphEnv, - create_edge_type_dict, - map_values, -) -from job_shop_lib.graphs import NodeType -from job_shop_lib.dispatching.feature_observers import FeatureType - -T = TypeVar("T", bound=np.number) -EnvType = TypeVar( # pylint: disable=invalid-name - "EnvType", bound=SingleJobShopGraphEnv | MultiJobShopGraphEnv -) - 
-_NODE_TYPE_TO_FEATURE_TYPE = { - NodeType.OPERATION: FeatureType.OPERATIONS, - NodeType.MACHINE: FeatureType.MACHINES, - NodeType.JOB: FeatureType.JOBS, -} -_FEATURE_TYPE_STR_TO_NODE_TYPE = { - FeatureType.OPERATIONS.value: NodeType.OPERATION, - FeatureType.MACHINES.value: NodeType.MACHINE, - FeatureType.JOBS.value: NodeType.JOB, -} - - -class ResourceTaskGraphObservationDict(TypedDict): - """Represents a dictionary for resource task graph observations.""" - - edge_index_dict: dict[tuple[str, str, str], NDArray[np.int32]] - node_features_dict: dict[str, NDArray[np.float32]] - original_ids_dict: dict[str, NDArray[np.int32]] - - -# pylint: disable=line-too-long -class ResourceTaskGraphObservation(ObservationWrapper, Generic[EnvType]): - """Observation wrapper that converts an observation following the - :class:`ObservationDict` format to a format suitable to PyG's - [`HeteroData`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.HeteroData.html). - - In particular, the ``edge_index`` is converted into a ``edge_index_dict`` - with keys ``(node_type_i, "to", node_type_j)``. The ``node_type_i`` and - ``node_type_j`` are the node types of the source and target nodes, - respectively. - - Additionally, the node features are stored in a dictionary with keys - corresponding to the node type names under the ``node_features_dict`` key. - - The node IDs are mapped to local IDs starting from 0. The - ``original_ids_dict`` contains the original node IDs before removing nodes. - - Attributes: - global_to_local_id: A dictionary mapping global node IDs to local node - IDs for each node type. - type_ranges: A dictionary mapping node type names to (start, end) index - ranges. - - Args: - env: The environment to wrap. 
- """ - - def __init__(self, env: EnvType): - super().__init__(env) - self.env = env # Unnecessary, but makes mypy happy - self.global_to_local_id = self._compute_id_mappings() - self.type_ranges = self._compute_node_type_ranges() - self._start_from_zero_mapping: dict[str, dict[int, int]] = {} - - def step(self, action: tuple[int, int]): - """Takes a step in the environment. - - Args: - action: - The action to take. The action is a tuple of two integers - (job_id, machine_id): - the job ID and the machine ID in which to schedule the - operation. - - Returns: - A tuple containing the following elements: - - - The observation of the environment. - - The reward obtained. - - Whether the environment is done. - - Whether the episode was truncated (always False). - - A dictionary with additional information. The dictionary - contains the following keys: "feature_names", the names of the - features in the observation; and "available_operations_with_ids", - a list of available actions in the form of (operation_id, - machine_id, job_id). - """ - observation, reward, done, truncated, info = self.env.step(action) - new_observation = self.observation(observation) - new_info = self._info(info) - return new_observation, reward, done, truncated, new_info - - def reset(self, *, seed: int | None = None, options: dict | None = None): - """Resets the environment. - - Args: - seed: - Added to match the signature of the parent class. It is not - used in this method. - options: - Additional options to pass to the environment. Not used in - this method. - - Returns: - A tuple containing the following elements: - - - The observation of the environment. - - A dictionary with additional information, keys - include: "feature_names", the names of the features in the - observation; and "available_operations_with_ids", a list of - available a list of available actions in the form of - (operation_id, machine_id, job_id). 
- """ - observation, info = self.env.reset() - new_observation = self.observation(observation) - new_info = self._info(info) - return new_observation, new_info - - def _info(self, info: dict[str, Any]) -> dict[str, Any]: - """Updates the "available_operations_with_ids" key in the info - dictionary so that they start from 0 using the - `_start_from_zero_mapping` attribute. - """ - new_available_operations_ids = [] - for operation_id, machine_id, job_id in info[ - "available_operations_with_ids" - ]: - if "operation" in self._start_from_zero_mapping: - operation_id = self._start_from_zero_mapping["operation"][ - operation_id - ] - if "machine" in self._start_from_zero_mapping: - machine_id = self._start_from_zero_mapping["machine"][ - machine_id - ] - if "job" in self._start_from_zero_mapping: - job_id = self._start_from_zero_mapping["job"][job_id] - new_available_operations_ids.append( - (operation_id, machine_id, job_id) - ) - info["available_operations_with_ids"] = new_available_operations_ids - return info - - def _compute_id_mappings(self) -> dict[int, int]: - """Computes mappings from global node IDs to type-local IDs. - - Returns: - A dictionary mapping global node IDs to local node IDs for each - node type. - """ - mappings = {} - for node_type in NodeType: - type_nodes = self.unwrapped.job_shop_graph.nodes_by_type[node_type] - if not type_nodes: - continue - # Create mapping from global ID to local ID - # (0 to len(type_nodes)-1) - type_mapping = { - node.node_id: local_id - for local_id, node in enumerate(type_nodes) - } - mappings.update(type_mapping) - - return mappings - - def _compute_node_type_ranges(self) -> dict[str, tuple[int, int]]: - """Computes index ranges for each node type. 
- - Returns: - Dictionary mapping node type names to (start, end) index ranges - """ - type_ranges = {} - for node_type in NodeType: - type_nodes = self.unwrapped.job_shop_graph.nodes_by_type[node_type] - if not type_nodes: - continue - start = min(node.node_id for node in type_nodes) - end = max(node.node_id for node in type_nodes) + 1 - type_ranges[node_type.name.lower()] = (start, end) - - return type_ranges - - def observation( - self, observation: ObservationDict - ) -> ResourceTaskGraphObservationDict: - """Processes the observation data into the resource task graph format. - - Args: - observation: The observation dictionary. It must NOT have padding. - - Returns: - A dictionary containing the following keys: - - - "edge_index_dict": A dictionary mapping edge types to edge index - arrays. - - "node_features_dict": A dictionary mapping node type names to - node feature arrays. - - "original_ids_dict": A dictionary mapping node type names to the - original node IDs before removing nodes. 
- """ - edge_index_dict = create_edge_type_dict( - observation["edge_index"], - type_ranges=self.type_ranges, - relationship="to", - ) - node_features_dict = self._create_node_features_dict(observation) - node_features_dict, original_ids_dict = self._remove_nodes( - node_features_dict, observation["removed_nodes"] - ) - - # mapping from global node ID to local node ID - for key, edge_index in edge_index_dict.items(): - edge_index_dict[key] = map_values( - edge_index, self.global_to_local_id - ) - # mapping so that ids start from 0 in edge index - self._start_from_zero_mapping = self._get_start_from_zero_mappings( - original_ids_dict - ) - for (type_1, to, type_2), edge_index in edge_index_dict.items(): - edge_index_dict[(type_1, to, type_2)][0] = map_values( - edge_index[0], self._start_from_zero_mapping[type_1] - ) - edge_index_dict[(type_1, to, type_2)][1] = map_values( - edge_index[1], self._start_from_zero_mapping[type_2] - ) - - return { - "edge_index_dict": edge_index_dict, - "node_features_dict": node_features_dict, - "original_ids_dict": original_ids_dict, - } - - @staticmethod - def _get_start_from_zero_mappings( - original_indices_dict: dict[str, NDArray[np.int32]], - ) -> dict[str, dict[int, int]]: - mappings: dict[str, dict[int, int]] = {} - for key, indices in original_indices_dict.items(): - mappings[key] = { - idx: i for i, idx in enumerate(indices) # type: ignore[misc] - } # idx is an integer (false positive) - return mappings - - def _create_node_features_dict( - self, observation: ObservationDict - ) -> dict[str, NDArray]: - """Creates a dictionary of node features for each node type. - - Args: - observation: The observation dictionary. - - Returns: - Dictionary mapping node type names to node features. 
- """ - - node_features_dict = {} - for node_type, feature_type in _NODE_TYPE_TO_FEATURE_TYPE.items(): - if self.unwrapped.job_shop_graph.nodes_by_type[node_type]: - node_features_dict[feature_type.value] = observation[ - feature_type.value - ] - continue - if feature_type != FeatureType.JOBS: - continue - assert FeatureType.OPERATIONS.value in observation - job_features = observation[ - feature_type.value # type: ignore[literal-required] - ] - job_ids_of_ops = [ - node.operation.job_id - for node in self.unwrapped.job_shop_graph.nodes_by_type[ - NodeType.OPERATION - ] - ] - job_features_expanded = job_features[job_ids_of_ops] - operation_features = observation[FeatureType.OPERATIONS.value] - node_features_dict[FeatureType.OPERATIONS.value] = np.concatenate( - (operation_features, job_features_expanded), axis=1 - ) - return node_features_dict - - def _remove_nodes( - self, - node_features_dict: dict[str, NDArray[T]], - removed_nodes: NDArray[np.bool_], - ) -> tuple[dict[str, NDArray[T]], dict[str, NDArray[np.int32]]]: - """Removes nodes from the node features dictionary. - - Args: - node_features_dict: The node features dictionary. - - Returns: - The node features dictionary with the nodes removed and a - dictionary containing the original node ids. 
- """ - removed_nodes_dict: dict[str, NDArray[T]] = {} - original_ids_dict: dict[str, NDArray[np.int32]] = {} - for feature_type, features in node_features_dict.items(): - node_type = _FEATURE_TYPE_STR_TO_NODE_TYPE[ - feature_type - ].name.lower() - if node_type not in self.type_ranges: - continue - start, end = self.type_ranges[node_type] - removed_nodes_of_this_type = removed_nodes[start:end] - removed_nodes_dict[node_type] = features[ - ~removed_nodes_of_this_type - ] - original_ids_dict[node_type] = np.where( - ~removed_nodes_of_this_type # type: ignore[assignment] - )[0] - - return removed_nodes_dict, original_ids_dict - - @property - def unwrapped(self) -> EnvType: - """Returns the unwrapped environment.""" - return self.env # type: ignore[return-value] diff --git a/job_shop_lib/reinforcement_learning/_single_job_shop_graph_env.py b/job_shop_lib/reinforcement_learning/_single_job_shop_graph_env.py index ac397b5f..5b5ea983 100644 --- a/job_shop_lib/reinforcement_learning/_single_job_shop_graph_env.py +++ b/job_shop_lib/reinforcement_learning/_single_job_shop_graph_env.py @@ -12,12 +12,11 @@ from numpy.typing import NDArray from job_shop_lib import JobShopInstance, Operation -from job_shop_lib.graphs import JobShopGraph +from job_shop_lib.graphs import JobShopGraph, NodeType from job_shop_lib.graphs.graph_updaters import ( GraphUpdater, ResidualGraphUpdater, ) -from job_shop_lib.exceptions import ValidationError from job_shop_lib.dispatching import ( Dispatcher, filter_dominated_operations, @@ -28,16 +27,24 @@ CompositeFeatureObserver, FeatureObserver, FeatureObserverType, + FeatureType, ) from job_shop_lib.visualization.gantt import GanttChartCreator from job_shop_lib.reinforcement_learning import ( RewardObserver, MakespanReward, - add_padding, RenderConfig, ObservationSpaceKey, ObservationDict, ) +from job_shop_lib.exceptions import ValidationError + + +_FEATURE_TYPE_STR_TO_NODE_TYPE = { + FeatureType.OPERATIONS.value: NodeType.OPERATION, + 
FeatureType.MACHINES.value: NodeType.MACHINE, + FeatureType.JOBS.value: NodeType.JOB, +} class SingleJobShopGraphEnv(gym.Env): @@ -51,13 +58,15 @@ class SingleJobShopGraphEnv(gym.Env): Observation Space: A dictionary with the following keys: - - "removed_nodes": Binary vector indicating removed graph nodes. - - "edge_list": Matrix of graph edges in COO format. - - Feature matrices: Keys corresponding to the composite observer - features (e.g., "operations", "jobs", "machines"). + - "edge_index_dict": Dictionary mapping edge types to + their COO format indices. + - "available_operations_with_ids": List of available actions + represented as (operation_id, machine_id, job_id) tuples. + - "node_features_dict": (optional) Dictionary mapping node types to + their feature matrices. Action Space: - MultiDiscrete space representing (job_id, machine_id) pairs. + MultiDiscrete space representing (operation_id, machine_id) pairs. Render Modes: @@ -85,17 +94,20 @@ class SingleJobShopGraphEnv(gym.Env): action_space: Defines the action space. The action is a tuple of two integers - (job_id, machine_id). The machine_id can be -1 if the selected - operation can only be scheduled in one machine. + (operation_id, machine_id). The machine_id can be -1 if the + selected operation can only be scheduled in one machine. observation_space: Defines the observation space. The observation is a dictionary with the following keys: - - "removed_nodes": Binary vector indicating removed graph nodes. - - "edge_list": Matrix of graph edges in COO format. - - Feature matrices: Keys corresponding to the composite observer - features (e.g., "operations", "jobs", "machines"). + - "edge_index_dict": Dictionary mapping edge types to + their COO format indices. + - "available_operations_with_ids": List of available + actions represented as + (operation_id, machine_id, job_id) tuples. + - "node_features_dict": (optional) Dictionary mapping + node types to their feature matrices. 
render_mode: The mode for rendering the environment ("human", "save_video", @@ -105,10 +117,6 @@ class SingleJobShopGraphEnv(gym.Env): Creates Gantt chart visualizations. See :class:`~job_shop_lib.visualization.GanttChartCreator`. - use_padding: - Whether to use padding in observations. Padding maintains the - observation space shape when the number of nodes changes. - Args: job_shop_graph: The JobShopGraph instance representing the job shop problem. @@ -127,9 +135,6 @@ class SingleJobShopGraphEnv(gym.Env): render_config: Configuration for rendering (e.g., paths for saving videos or GIFs). See :class:`~job_shop_lib.visualization.RenderConfig`. - use_padding: - Whether to use padding in observations. Padding maintains the - observation space shape when the number of nodes changes. """ metadata = {"render_modes": ["human", "save_video", "save_gif"]} @@ -158,7 +163,6 @@ def __init__( ) = filter_dominated_operations, render_mode: str | None = None, render_config: RenderConfig | None = None, - use_padding: bool = True, ) -> None: super().__init__() # Used for resetting the environment @@ -186,6 +190,7 @@ def __init__( self.action_space = gym.spaces.MultiDiscrete( [self.instance.num_jobs, self.instance.num_machines], start=[0, -1] ) + self.observation_space: gym.spaces.Dict = self._get_observation_space() self.render_mode = render_mode if render_config is None: @@ -193,7 +198,6 @@ def __init__( self.gantt_chart_creator = GanttChartCreator( dispatcher=self.dispatcher, **render_config ) - self.use_padding = use_padding @property def instance(self) -> JobShopInstance: @@ -240,29 +244,60 @@ def machine_utilization( # noqa: DOC201,DOC203 def _get_observation_space(self) -> gym.spaces.Dict: """Returns the observation space dictionary.""" - num_edges = self.job_shop_graph.num_edges - dict_space: dict[str, gym.Space] = { - ObservationSpaceKey.REMOVED_NODES.value: gym.spaces.MultiBinary( - len(self.job_shop_graph.nodes) - ), - ObservationSpaceKey.EDGE_INDEX.value: 
gym.spaces.MultiDiscrete( - np.full( - (2, num_edges), - fill_value=len(self.job_shop_graph.nodes) + 1, - dtype=np.int32, - ), - start=np.full( - (2, num_edges), - fill_value=-1, # -1 is used for padding + + obs_space = gym.spaces.Dict() + initial_edge_index_dict = self.initial_job_shop_graph.edge_index_dict + edge_index_space = gym.spaces.Dict( + { + key: gym.spaces.Box( # type: ignore + low=0, + high=np.iinfo(np.int32).max, + shape=edges.shape, dtype=np.int32, - ), - ), - } - for feature_type, matrix in self.composite_observer.features.items(): - dict_space[feature_type.value] = gym.spaces.Box( - low=-np.inf, high=np.inf, shape=matrix.shape + ) + for key, edges in initial_edge_index_dict.items() + } + ) + obs_space[ObservationSpaceKey.EDGE_INDEX] = edge_index_space + + num_available_actions = len(self.get_available_actions_with_ids()) + available_actions_with_ids_space = gym.spaces.Box( + low=np.full((num_available_actions, 3), -1, dtype=np.int32), + high=np.array( + [ + len(self.job_shop_graph.nodes_by_type[NodeType.OPERATION]) + - 1, + len(self.job_shop_graph.nodes_by_type[NodeType.MACHINE]) + - 1, + len(self.job_shop_graph.nodes_by_type[NodeType.JOB]) - 1, + ], + dtype=np.int32, ) - return gym.spaces.Dict(dict_space) + .reshape(1, 3) + .repeat(num_available_actions, axis=0), + shape=(num_available_actions, 3), + dtype=np.int32, + ) + obs_space[ObservationSpaceKey.ACTION_MASK] = ( + available_actions_with_ids_space + ) + if not self.composite_observer.features: + return obs_space + node_features_space = gym.spaces.Dict( + { + feature_type.value: gym.spaces.Box( + low=-np.inf, + high=np.inf, + shape=matrix.shape, + dtype=np.float32, + ) + for feature_type, matrix in + self.composite_observer.features.items() + } + ) + obs_space[ObservationSpaceKey.NODE_FEATURES] = node_features_space + + return obs_space def reset( self, @@ -293,12 +328,7 @@ def reset( super().reset(seed=seed, options=options) self.dispatcher.reset() obs = self.get_observation() - return obs, { 
- "feature_names": self.composite_observer.column_names, - "available_operations_with_ids": ( - self.get_available_actions_with_ids() - ), - } + return obs, {"feature_names": self.composite_observer.column_names} def step( self, action: tuple[int, int] @@ -308,8 +338,8 @@ def step( Args: action: The action to take. The action is a tuple of two integers - (job_id, machine_id): - the job ID and the machine ID in which to schedule the + (operation_id, machine_id): + the operation ID and the machine ID in which to schedule the operation. Returns: @@ -321,14 +351,18 @@ def step( - Whether the episode was truncated (always False). - A dictionary with additional information. The dictionary contains the following keys: "feature_names", the names of the - features in the observation; and "available_operations_with_ids", - a list of available actions in the form of (operation_id, - machine_id, job_id). + features in the observation. """ - job_id, machine_id = action - operation = self.dispatcher.next_operation(job_id) - if machine_id == -1: + node_operation_id, node_machine_id = action + operation = self.job_shop_graph.nodes_map[ + ("operation", node_operation_id) + ].operation + if node_machine_id == -1: machine_id = operation.machine_id + else: + machine_id = self.job_shop_graph.nodes_map[ + ("machine", node_machine_id) + ].machine_id self.dispatcher.dispatch(operation, machine_id) @@ -337,41 +371,39 @@ def step( done = self.dispatcher.schedule.is_complete() truncated = False info: dict[str, Any] = { - "feature_names": self.composite_observer.column_names, - "available_operations_with_ids": ( - self.get_available_actions_with_ids() - ), + "feature_names": self.composite_observer.column_names } return obs, reward, done, truncated, info def get_observation(self) -> ObservationDict: """Returns the current observation of the environment.""" + node_features_dict: dict[str, NDArray[np.float32]] = {} + removed_nodes = self.job_shop_graph.removed_nodes + + for feature_type, matrix 
in self.composite_observer.features.items(): + # Use the provided mapping for robust conversion + node_type = _FEATURE_TYPE_STR_TO_NODE_TYPE[feature_type.value] + removed_mask = removed_nodes.get(node_type.name.lower()) + + current_matrix = matrix + if removed_mask is not None: + # The mask is True for removed nodes; invert to get active ones + active_mask = ~np.array(removed_mask, dtype=bool) + current_matrix = matrix[active_mask] + + node_features_dict[feature_type.value] = current_matrix + + # Construct the final observation dictionary with the nested structure observation: ObservationDict = { - ObservationSpaceKey.REMOVED_NODES.value: np.array( - self.job_shop_graph.removed_nodes, dtype=bool - ), - ObservationSpaceKey.EDGE_INDEX.value: self._get_edge_index(), + ObservationSpaceKey.EDGE_INDEX: # type: ignore + self.job_shop_graph.edge_index_dict, + ObservationSpaceKey.ACTION_MASK: # type: ignore + self.get_available_actions_with_ids(), + ObservationSpaceKey.NODE_FEATURES: # type: ignore + node_features_dict, } - for feature_type, matrix in self.composite_observer.features.items(): - observation[feature_type.value] = matrix return observation - def _get_edge_index(self) -> NDArray[np.int32]: - """Returns the edge index matrix.""" - edge_index = np.array( - self.job_shop_graph.graph.edges(), dtype=np.int32 - ).T - - if self.use_padding: - output_shape = self.observation_space[ - ObservationSpaceKey.EDGE_INDEX.value - ].shape - assert output_shape is not None # For the type checker - edge_index = add_padding( - edge_index, output_shape=output_shape, dtype=np.int32 - ) - return edge_index - def render(self): """Renders the environment. 
@@ -391,19 +423,38 @@ def render(self): elif self.render_mode == "save_gif": self.gantt_chart_creator.create_gif() - def get_available_actions_with_ids(self) -> list[tuple[int, int, int]]: + def get_available_actions_with_ids(self) -> NDArray[np.int32]: """Returns a list of available actions in the form of (operation_id, machine_id, job_id).""" available_operations = self.dispatcher.available_operations() available_operations_with_ids = [] for operation in available_operations: - job_id = operation.job_id - operation_id = operation.operation_id - for machine_id in operation.machines: + # For now only local operation ids are obtained + # from the graph + # jobs or machine ids will not be included + # if not present in the graph + operation_id = self.job_shop_graph.get_operation_node( + operation.operation_id + ).node_id[1] + if len(self.job_shop_graph.nodes_by_type[NodeType.JOB]) > 0: + job_id = self.job_shop_graph.get_job_node( + operation.job_id + ).node_id[1] + else: + job_id = -1 # Use -1 to indicate job_id is not in the graph + if len(self.job_shop_graph.nodes_by_type[NodeType.MACHINE]) > 0: + for machine_id in operation.machines: + machine_id = self.job_shop_graph.get_machine_node( + machine_id + ).node_id[1] + available_operations_with_ids.append( + [operation_id, machine_id, job_id] + ) + else: available_operations_with_ids.append( - (operation_id, machine_id, job_id) + [operation_id, -1, job_id] ) - return available_operations_with_ids + return np.array(available_operations_with_ids, dtype=np.int32) def validate_action(self, action: tuple[int, int]) -> None: """Validates that the action is legal in the current state. 
diff --git a/job_shop_lib/reinforcement_learning/_types_and_constants.py b/job_shop_lib/reinforcement_learning/_types_and_constants.py index 71978f66..5b1e7840 100644 --- a/job_shop_lib/reinforcement_learning/_types_and_constants.py +++ b/job_shop_lib/reinforcement_learning/_types_and_constants.py @@ -7,7 +7,6 @@ import numpy as np from numpy.typing import NDArray -from job_shop_lib.dispatching.feature_observers import FeatureType from job_shop_lib.visualization.gantt import ( PartialGanttChartPlotterConfig, GifConfig, @@ -26,37 +25,56 @@ class RenderConfig(TypedDict, total=False): class ObservationSpaceKey(str, Enum): """Enumeration of the keys for the observation space dictionary.""" - REMOVED_NODES = "removed_nodes" - EDGE_INDEX = "edge_index" - OPERATIONS = FeatureType.OPERATIONS.value - JOBS = FeatureType.JOBS.value - MACHINES = FeatureType.MACHINES.value + EDGE_INDEX = "edge_index_dict" + NODE_FEATURES = "node_features_dict" + ACTION_MASK = "available_operations_with_ids" + + +# NEW: TypedDict for the nested node features dictionary. +class NodeFeaturesDict(TypedDict, total=False): + """A dictionary containing feature matrices for different node types. + + Keys correspond to FeatureType values (e.g., 'operations', 'jobs'). + Values are the corresponding feature matrices (num_nodes, num_features). + """ + + operations: NDArray[np.float32] + jobs: NDArray[np.float32] + machines: NDArray[np.float32] class _ObservationDictRequired(TypedDict): """Required fields for the observation dictionary.""" - removed_nodes: NDArray[np.bool_] - edge_index: NDArray[np.int32] + edge_index_dict: dict[tuple[str, str, str], NDArray[np.int32]] + available_operations_with_ids: list[tuple[int, int, int]] +# UPDATED: Now contains the nested dictionary for node features. 
class _ObservationDictOptional(TypedDict, total=False): """Optional fields for the observation dictionary.""" - operations: NDArray[np.float32] - jobs: NDArray[np.float32] - machines: NDArray[np.float32] + node_features_dict: NodeFeaturesDict +# UPDATED: Docstring now reflects the new nested structure. class ObservationDict(_ObservationDictRequired, _ObservationDictOptional): """A dictionary containing the observation of the environment. - Required fields: - removed_nodes: Binary vector indicating removed nodes. - edge_index: Edge list in COO format. + This dictionary represents a heterogeneous graph structure. + Required fields: + edge_index_dict: A dictionary mapping edge types + (source_type, relation, destination_type) to their respective + edge index tensors in COO format. + available_operations_with_ids: A list of tuples representing the + available operations and their IDs, where each tuple is of the + form + (local_operation_node_id, + local_machine_node_id, + local_job_node_id) + if nodes of each type are present, else -1. Optional fields: - operations: Matrix of operation features. - jobs: Matrix of job features. - machines: Matrix of machine features. + node_features_dict: A dictionary mapping node type names (from + FeatureType) to their corresponding feature matrices.
""" diff --git a/job_shop_lib/visualization/graphs/_plot_disjunctive_graph.py b/job_shop_lib/visualization/graphs/_plot_disjunctive_graph.py index 44a00b3a..428535a7 100644 --- a/job_shop_lib/visualization/graphs/_plot_disjunctive_graph.py +++ b/job_shop_lib/visualization/graphs/_plot_disjunctive_graph.py @@ -216,13 +216,13 @@ def plot_disjunctive_graph( graphviz_layout, prog="dot", args="-Grankdir=LR" ) - temp_graph = copy.deepcopy(job_shop_graph.graph) + temp_graph = copy.deepcopy(job_shop_graph.get_networkx_graph()) # Remove disjunctive edges to get a better layout temp_graph.remove_edges_from( [ (u, v) - for u, v, d in job_shop_graph.graph.edges(data=True) - if d["type"] == EdgeType.DISJUNCTIVE + for u, v, d in job_shop_graph.get_networkx_graph().edges(data=True) + if d["type"][1] == EdgeType.DISJUNCTIVE.name ] ) @@ -244,7 +244,7 @@ def plot_disjunctive_graph( cmap_func = matplotlib.colormaps.get_cmap(color_map) remaining_machines = job_shop_graph.instance.num_machines for operation_node in operation_nodes: - if job_shop_graph.is_removed(operation_node.node_id): + if job_shop_graph.is_removed(operation_node): continue machine_id = operation_node.operation.machine_id if machine_id not in machine_colors: @@ -258,12 +258,12 @@ def plot_disjunctive_graph( node_colors: list[Any] = [ _get_node_color(node) for node in job_shop_graph.nodes - if not job_shop_graph.is_removed(node.node_id) + if not job_shop_graph.is_removed(node) ] else: node_colors = [] for node in job_shop_graph.nodes: - if job_shop_graph.is_removed(node.node_id): + if job_shop_graph.is_removed(node): continue if node.node_type == NodeType.OPERATION: machine_id = node.operation.machine_id @@ -272,7 +272,7 @@ def plot_disjunctive_graph( node_colors.append(machine_colors[machine_id]) nx.draw_networkx_nodes( - job_shop_graph.graph, + job_shop_graph.get_networkx_graph(), pos, node_size=node_size, node_color=node_colors, @@ -284,13 +284,13 @@ def plot_disjunctive_graph( # ---------- conjunctive_edges = [ 
(u, v) - for u, v, d in job_shop_graph.graph.edges(data=True) - if d["type"] == EdgeType.CONJUNCTIVE + for u, v, d in job_shop_graph.get_networkx_graph().edges(data=True) + if d["type"][1] == EdgeType.CONJUNCTIVE.name ] disjunctive_edges: Iterable[tuple[int, int]] = [ (u, v) - for u, v, d in job_shop_graph.graph.edges(data=True) - if d["type"] == EdgeType.DISJUNCTIVE + for u, v, d in job_shop_graph.get_networkx_graph().edges(data=True) + if d["type"][1] == EdgeType.DISJUNCTIVE.name ] if conjunctive_edges_additional_params is None: conjunctive_edges_additional_params = {} @@ -298,7 +298,7 @@ def plot_disjunctive_graph( disjunctive_edges_additional_params = {} nx.draw_networkx_edges( - job_shop_graph.graph, + job_shop_graph.get_networkx_graph(), pos, edgelist=conjunctive_edges, width=edge_width, @@ -317,7 +317,7 @@ def plot_disjunctive_graph( disjunctive_edges_filtered.add((u, v)) disjunctive_edges = disjunctive_edges_filtered nx.draw_networkx_edges( - job_shop_graph.graph, + job_shop_graph.get_networkx_graph(), pos, edgelist=disjunctive_edges, width=edge_width, @@ -331,20 +331,20 @@ def plot_disjunctive_graph( labels = {} if job_shop_graph.nodes_by_type[NodeType.SOURCE]: source_node = job_shop_graph.nodes_by_type[NodeType.SOURCE][0] - if not job_shop_graph.is_removed(source_node.node_id): + if not job_shop_graph.is_removed(source_node): labels[source_node] = start_node_label if job_shop_graph.nodes_by_type[NodeType.SINK]: sink_node = job_shop_graph.nodes_by_type[NodeType.SINK][0] # check if the sink node is removed - if not job_shop_graph.is_removed(sink_node.node_id): + if not job_shop_graph.is_removed(sink_node): labels[sink_node] = end_node_label for operation_node in operation_nodes: - if job_shop_graph.is_removed(operation_node.node_id): + if job_shop_graph.is_removed(operation_node): continue labels[operation_node] = operation_node_labeler(operation_node) nx.draw_networkx_labels( - job_shop_graph.graph, + job_shop_graph.get_networkx_graph(), pos, 
labels=labels, font_color=node_font_color, diff --git a/job_shop_lib/visualization/graphs/_plot_resource_task_graph.py b/job_shop_lib/visualization/graphs/_plot_resource_task_graph.py index e65b6b97..08ee6f1b 100644 --- a/job_shop_lib/visualization/graphs/_plot_resource_task_graph.py +++ b/job_shop_lib/visualization/graphs/_plot_resource_task_graph.py @@ -94,7 +94,7 @@ def plot_resource_task_graph( fig.suptitle(title) # Create the networkx graph - graph = job_shop_graph.graph + graph = job_shop_graph.get_networkx_graph() nodes = job_shop_graph.non_removed_nodes() # Create the layout if it was not provided @@ -114,9 +114,9 @@ def plot_resource_task_graph( if node_color_map is None else node_color_map ) - node_colors = [ - node_color_map(node) for node in job_shop_graph.nodes - ] # We need to get the color of all nodes to avoid an index error + node_colors = { + node.node_id: node_color_map(node) for node in job_shop_graph.nodes + } # We need to get the color of all nodes to avoid an index error if node_shapes is None: node_shapes = { "machine": "s", diff --git a/pyproject.toml b/pyproject.toml index 0bd7dbc9..6e6b3805 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "job-shop-lib" -version = "1.6.1" +version = "1.7.0" description = "An easy-to-use and modular Python library for the Job Shop Scheduling Problem (JSSP)" authors = ["Pabloo22 "] license = "MIT" diff --git a/tests/.DS_Store b/tests/.DS_Store new file mode 100644 index 00000000..bb5c45f7 Binary files /dev/null and b/tests/.DS_Store differ diff --git a/tests/conftest.py b/tests/conftest.py index c65c1848..94eb30d6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ -import pytest import random +import pytest + from job_shop_lib import ( JobShopInstance, Operation, @@ -24,7 +25,10 @@ build_resource_task_graph, JobShopGraph, ) -from job_shop_lib.benchmarking import load_benchmark_instance +from job_shop_lib.benchmarking import ( + 
load_benchmark_instance, + load_all_benchmark_instances, +) @pytest.fixture(name="job_shop_instance") @@ -165,7 +169,11 @@ def single_job_shop_graph_env_ft06() -> SingleJobShopGraphEnv: DispatcherObserverConfig( FeatureObserverType.IS_READY, kwargs={"feature_types": [FeatureType.JOBS]}, - ) + ), + DispatcherObserverConfig( + FeatureObserverType.IS_COMPLETED, + kwargs={"feature_types": [FeatureType.OPERATIONS]}, + ), ] env = SingleJobShopGraphEnv( @@ -233,7 +241,11 @@ def multi_job_shop_graph_env() -> MultiJobShopGraphEnv: DispatcherObserverConfig( FeatureObserverType.IS_READY, kwargs={"feature_types": [FeatureType.JOBS]}, - ) + ), + DispatcherObserverConfig( + FeatureObserverType.IS_COMPLETED, + kwargs={"feature_types": [FeatureType.OPERATIONS]}, + ), ] env = MultiJobShopGraphEnv( @@ -348,3 +360,10 @@ def two_machines_instance() -> JobShopInstance: # Two jobs, each with one operation on different machines jobs = [[Operation(0, 5)], [Operation(1, 3)]] return JobShopInstance(jobs, name="TwoMachines") + + +@pytest.fixture +def first_ten_benchmark_instances() -> list[JobShopInstance]: + """Load the first ten benchmark instances.""" + all_instances = load_all_benchmark_instances() + return list(all_instances.values())[:10] diff --git a/tests/dispatching/feature_observers/test_dates_observer.py b/tests/dispatching/feature_observers/test_dates_observer.py index c707b185..23c27b76 100644 --- a/tests/dispatching/feature_observers/test_dates_observer.py +++ b/tests/dispatching/feature_observers/test_dates_observer.py @@ -70,9 +70,7 @@ def test_update_features(dispatcher_with_extras: Dispatcher): ) while not dispatcher_with_extras.schedule.is_complete(): - op = ( - dispatcher_with_extras.available_operations()[0] - ) + op = dispatcher_with_extras.available_operations()[0] dispatcher_with_extras.dispatch(op) current_time = dispatcher_with_extras.current_time() expected_features = initial_features - current_time diff --git 
a/tests/dispatching/feature_observers/test_feature_observer.py b/tests/dispatching/feature_observers/test_feature_observer.py index 93d27899..0237c6c0 100644 --- a/tests/dispatching/feature_observers/test_feature_observer.py +++ b/tests/dispatching/feature_observers/test_feature_observer.py @@ -331,7 +331,7 @@ def test_str_method_multiple_features(self, dispatcher): def test_class_attributes(self): """Test class attributes have correct default values.""" # pylint: disable=protected-access - assert FeatureObserver._is_singleton is False + assert FeatureObserver._is_singleton is False assert FeatureObserver._feature_sizes == 1 assert FeatureObserver._supported_feature_types == list(FeatureType) diff --git a/tests/graphs/test_build_disjunctive_graph.py b/tests/graphs/test_build_disjunctive_graph.py index 8e5f1ca0..850bc3a5 100644 --- a/tests/graphs/test_build_disjunctive_graph.py +++ b/tests/graphs/test_build_disjunctive_graph.py @@ -10,12 +10,14 @@ def test_disjunctive_edges_addition(example_job_shop_instance): continue for node1, node2 in itertools.combinations(machine_operations, 2): assert ( - graph.graph.has_edge(node1, node2) - and graph.graph[node1][node2]["type"] == EdgeType.DISJUNCTIVE + graph.get_networkx_graph().has_edge(node1, node2) + and graph.get_networkx_graph()[node1][node2]["type"][1] + == EdgeType.DISJUNCTIVE.name ) assert ( - graph.graph.has_edge(node2, node1) - and graph.graph[node2][node1]["type"] == EdgeType.DISJUNCTIVE + graph.get_networkx_graph().has_edge(node2, node1) + and graph.get_networkx_graph()[node2][node1]["type"][1] + == EdgeType.DISJUNCTIVE.name ) @@ -24,11 +26,13 @@ def test_conjunctive_edges_addition(example_job_shop_instance): for job_operations in graph.nodes_by_job: for i in range(1, len(job_operations)): assert ( - graph.graph.has_edge(job_operations[i - 1], job_operations[i]) - and graph.graph[job_operations[i - 1]][job_operations[i]][ - "type" - ] - == EdgeType.CONJUNCTIVE + graph.get_networkx_graph().has_edge( + 
job_operations[i - 1], job_operations[i] + ) + and graph.get_networkx_graph()[job_operations[i - 1]][ + job_operations[i] + ]["type"][1] + == EdgeType.CONJUNCTIVE.name ) @@ -54,12 +58,12 @@ def test_source_and_sink_edges_addition(example_job_shop_instance): ) for job_operations in graph.nodes_by_job: assert ( - graph.graph.has_edge(source, job_operations[0]) - and graph.graph[source][job_operations[0]]["type"] - == EdgeType.CONJUNCTIVE + graph.get_networkx_graph().has_edge(source, job_operations[0]) + and graph.get_networkx_graph()[source][job_operations[0]]["type"][1] + == EdgeType.CONJUNCTIVE.name ) assert ( - graph.graph.has_edge(job_operations[-1], sink) - and graph.graph[job_operations[-1]][sink]["type"] - == EdgeType.CONJUNCTIVE + graph.get_networkx_graph().has_edge(job_operations[-1], sink) + and graph.get_networkx_graph()[job_operations[-1]][sink]["type"][1] + == EdgeType.CONJUNCTIVE.name ) diff --git a/tests/graphs/test_build_resource_task_graphs.py b/tests/graphs/test_build_resource_task_graphs.py index 774b16dd..b4244b07 100644 --- a/tests/graphs/test_build_resource_task_graphs.py +++ b/tests/graphs/test_build_resource_task_graphs.py @@ -4,12 +4,10 @@ build_resource_task_graph, ) from job_shop_lib import JobShopInstance -from job_shop_lib.benchmarking import load_all_benchmark_instances -def test_expected_num_nodes_complete(): - benchmark_instances = load_all_benchmark_instances() - for instance in list(benchmark_instances.values())[:10]: +def test_expected_num_nodes_complete(first_ten_benchmark_instances): + for instance in first_ten_benchmark_instances: graph = build_complete_resource_task_graph(instance) expected_num_nodes = get_expected_num_nodes_for_complete_graph( instance @@ -17,9 +15,8 @@ def test_expected_num_nodes_complete(): assert len(graph.nodes) == expected_num_nodes -def test_expected_num_nodes_with_jobs(): - benchmark_instances = load_all_benchmark_instances() - for instance in list(benchmark_instances.values())[:10]: +def 
test_expected_num_nodes_with_jobs(first_ten_benchmark_instances): + for instance in first_ten_benchmark_instances: graph = build_resource_task_graph_with_jobs(instance) expected_num_nodes = get_expected_num_nodes_for_graph_with_jobs( instance @@ -27,17 +24,15 @@ def test_expected_num_nodes_with_jobs(): assert len(graph.nodes) == expected_num_nodes -def test_expected_num_nodes(): - benchmark_instances = load_all_benchmark_instances() - for instance in list(benchmark_instances.values())[:10]: +def test_expected_num_nodes(first_ten_benchmark_instances): + for instance in first_ten_benchmark_instances: graph = build_resource_task_graph(instance) expected_num_nodes = get_expected_num_nodes(instance) assert len(graph.nodes) == expected_num_nodes -def test_expected_num_edges_complete(): - benchmark_instances = load_all_benchmark_instances() - for instance in list(benchmark_instances.values())[:10]: +def test_expected_num_edges_complete(first_ten_benchmark_instances): + for instance in first_ten_benchmark_instances: graph = build_complete_resource_task_graph(instance) expected_num_edges = get_expected_num_edges_for_complete_graph( instance @@ -67,9 +62,8 @@ def test_expected_num_edges_example(example_job_shop_instance): assert graph.num_edges == expected_num_edges -def test_expected_num_edges_with_jobs(): - benchmark_instances = load_all_benchmark_instances() - for instance in list(benchmark_instances.values())[:10]: +def test_expected_num_edges_with_jobs(first_ten_benchmark_instances): + for instance in first_ten_benchmark_instances: graph = build_resource_task_graph_with_jobs(instance) expected_num_edges = get_expected_num_edges_for_graph_with_jobs( instance @@ -77,9 +71,8 @@ def test_expected_num_edges_with_jobs(): assert graph.num_edges == expected_num_edges -def test_expected_num_edges(): - benchmark_instances = load_all_benchmark_instances() - for instance in list(benchmark_instances.values())[:10]: +def test_expected_num_edges(first_ten_benchmark_instances): + for 
instance in first_ten_benchmark_instances: graph = build_resource_task_graph(instance) expected_num_edges = get_expected_num_edges(instance) assert graph.num_edges == expected_num_edges diff --git a/tests/graphs/test_job_shop_graph.py b/tests/graphs/test_job_shop_graph.py index 5d76fd32..f3b76b66 100644 --- a/tests/graphs/test_job_shop_graph.py +++ b/tests/graphs/test_job_shop_graph.py @@ -1,3 +1,4 @@ +from collections import defaultdict import pytest import networkx as nx @@ -23,18 +24,48 @@ def test_nodes(example_job_shop_instance): graph = JobShopGraph(example_job_shop_instance) add_source_sink_nodes(graph) assert graph.nodes == [ - data["node"] for _, data in graph.graph.nodes(data=True) + data["node"] for _, data in graph.get_networkx_graph().nodes(data=True) ] def test_node_ids(example_job_shop_instance): + """ + Tests that node IDs are correctly formatted tuples, that nodes can be + retrieved by their ID, and that local IDs are sequential per type. + """ graph = JobShopGraph(example_job_shop_instance) add_source_sink_nodes(graph) - # We don't use enumerate here because we want to test if we can - # access the node by its id - for i in range(graph.graph.number_of_nodes()): - assert graph.nodes[i].node_id == i + nodes_by_type = defaultdict(list) + + # Part 1: Verify ID format and that the node is accessible by its ID + for node in graph.nodes: + node_id = node.node_id + + # Assert the ID is a tuple of (str, int) + assert isinstance(node_id, tuple) + assert len(node_id) == 2 + assert isinstance(node_id[0], str) + assert isinstance(node_id[1], int) + + # Assert the node type in the ID matches the node's actual type + assert node_id[0] == node.node_type.name.lower() + + # Assert that we can retrieve the exact same node using its ID + # This tests the `nodes_map` functionality. 
+ assert graph.nodes_map[node_id] is node + + # Group nodes for the next part of the test + nodes_by_type[node.node_type].append(node) + + # Part 2: Verify that local IDs are sequential for each type + for _, nodes_of_that_type in nodes_by_type.items(): + # Extract the local_id (the integer) from each node's tuple ID + local_ids = [node.node_id[1] for node in nodes_of_that_type] + local_ids.sort() + + # Assert that the local IDs form a complete sequence from 0..N-1 + assert local_ids == list(range(len(nodes_of_that_type))) def test_node_types(example_job_shop_instance): @@ -101,40 +132,51 @@ def test_remove_node(example_job_shop_instance): add_conjunctive_edges(graph) add_disjunctive_edges(graph) add_source_sink_edges(graph) + # Dynamically select valid tuple IDs to remove. + # This is more robust than hardcoding integer indices. + op_nodes = graph.nodes_by_type[NodeType.OPERATION] + # Ensure there are enough nodes to run the test + assert len(op_nodes) >= 7, "Test instance needs at least 7 operations" + nodes_to_remove = [ + op_nodes[0], + op_nodes[3], + op_nodes[6], + ] - # Assumption: graph initially has nodes to remove - node_to_remove = graph.nodes[0].node_id - - nodes_to_remove = [0, 3, 6] + # Remove nodes using their correct tuple IDs + for node in nodes_to_remove: + graph.remove_node(node.node_id) - for node_id in nodes_to_remove: - graph.remove_node(node_id) + # Verify the nodes are no longer in the graph's node set + for node in nodes_to_remove: + if node.node_id in graph.nodes_map: + assert ( + graph.nodes_map[node.node_id].operation.operation_id + != node.operation.operation_id + ) - # Verify the node is no longer in the graph - for node_to_remove in nodes_to_remove: - assert node_to_remove not in graph.graph.nodes() + # Verify the `removed_nodes` attribute is updated correctly via the API + for node in nodes_to_remove: + assert graph.is_removed(node) + with pytest.raises(nx.NetworkXError): + # Seeing all edges of removed nodes just returns an empty 
list, not an error + # So we try to access an edge that should not exist + last_removed_node_id = nodes_to_remove[-1] + graph.get_networkx_graph().remove_edge( + last_removed_node_id, ("SOURCE", 0) + ) + + # This part of the test remains valid as it uses the is_removed() helper graph.remove_isolated_nodes() - # Verify isolated nodes are also removed and that the source node has - # been removed due to the removal of the isolated nodes - isolated_nodes = list(nx.isolates(graph.graph)) + isolated_nodes = list(nx.isolates(graph.get_networkx_graph())) assert not isolated_nodes - source_node = graph.nodes_by_type[NodeType.SOURCE][0] - assert graph.is_removed(source_node) - - # Verify the `removed_nodes` list is updated correctly - for node_to_remove in nodes_to_remove: - assert graph.removed_nodes[node_to_remove] - - # Optional: Check that no edges remain that involve the removed node - with pytest.raises(nx.NetworkXError): - graph.graph.edges(node_to_remove) - # Check the integrity of the remaining graph structure + # Verify the integrity of the remaining graph structure remaining_node_ids = { node.node_id for node in graph.nodes if not graph.is_removed(node) } - for u, v in graph.graph.edges(): + for u, v in graph.get_networkx_graph().edges(): assert u in remaining_node_ids assert v in remaining_node_ids diff --git a/tests/graphs/test_residual_graph_updater.py b/tests/graphs/test_residual_graph_updater.py index 288af5d8..85ece368 100644 --- a/tests/graphs/test_residual_graph_updater.py +++ b/tests/graphs/test_residual_graph_updater.py @@ -17,7 +17,7 @@ def _verify_all_nodes_removed(job_shop_graph: JobShopGraph): for node in job_shop_graph.nodes: assert job_shop_graph.is_removed( - node.node_id + node ), f"Node {node.node_id} was not removed." 
diff --git a/tests/reinforcement_learning/.DS_Store b/tests/reinforcement_learning/.DS_Store new file mode 100644 index 00000000..5d489114 Binary files /dev/null and b/tests/reinforcement_learning/.DS_Store differ diff --git a/tests/reinforcement_learning/test_multi_job_shop_graph_env.py b/tests/reinforcement_learning/test_multi_job_shop_graph_env.py index 7fe41c03..2dd63d32 100644 --- a/tests/reinforcement_learning/test_multi_job_shop_graph_env.py +++ b/tests/reinforcement_learning/test_multi_job_shop_graph_env.py @@ -2,6 +2,8 @@ import numpy as np +import pytest + from job_shop_lib.reinforcement_learning import ( MultiJobShopGraphEnv, ObservationSpaceKey, @@ -12,36 +14,18 @@ from job_shop_lib.dispatching.feature_observers import ( CompositeFeatureObserver, IsCompletedObserver, + FeatureType, ) +from job_shop_lib.graphs import NodeType from job_shop_lib.graphs.graph_updaters import ResidualGraphUpdater def _random_action(observation: ObservationDict) -> tuple[int, int]: - ready_operations = [] - for operation_id, is_ready in enumerate( - observation[ObservationSpaceKey.JOBS.value].ravel() - ): - if is_ready == 1.0: - ready_operations.append(operation_id) - - operation_id = random.choice(ready_operations) - machine_id = -1 # We can use -1 if each operation can only be scheduled - # in one machine. 
- return (operation_id, machine_id) - - -def test_consistent_observation_space( - multi_job_shop_graph_env: MultiJobShopGraphEnv, -): - """Tests that the observation space is consistent across multiple - resets.""" - - env = multi_job_shop_graph_env - observation_space = multi_job_shop_graph_env.observation_space - - for _ in range(100): - _ = env.reset() - assert observation_space == env.observation_space + available_operations_with_ids = observation[ + ObservationSpaceKey.ACTION_MASK.value + ] + operation_id, machine_id, _ = random.choice(available_operations_with_ids) + return (int(operation_id), int(machine_id)) def test_observation_space( @@ -50,59 +34,112 @@ def test_observation_space( random.seed(42) env = multi_job_shop_graph_env - observation_space = multi_job_shop_graph_env.observation_space - edge_index_shape = observation_space[ - ObservationSpaceKey.EDGE_INDEX.value - ].shape for _ in range(100): done = False obs, _ = env.reset() + observation_space = env.observation_space assert observation_space.contains(obs) while not done: action = _random_action(obs) obs, _, done, *_ = env.step(action) - assert observation_space.contains(obs) - env.use_padding = False done = False obs, _ = env.reset() + edge_index_shape = [2, 0] + for _, edges in list(obs[ObservationSpaceKey.EDGE_INDEX.value].items()): + edge_index_shape[1] += edges.shape[1] + edge_index_has_changed = False while not done: action = _random_action(obs) obs, _, done, *_ = env.step(action) edge_index = obs[ObservationSpaceKey.EDGE_INDEX.value] - if edge_index.shape != edge_index_shape: + shape = [2, 0] + for edge in edge_index.values(): + shape[1] += edge.shape[1] + if tuple(shape) != edge_index_shape: edge_index_has_changed = True break assert edge_index_has_changed -def test_edge_index_padding( +def test_observation( multi_job_shop_graph_env: MultiJobShopGraphEnv, ): - random.seed(100) + """ + Tests the integrity of the observation space throughout a full episode. + + This test verifies that: + 1. 
Edge indices in the observation correctly map to active nodes in the graph. + 2. Node features for completed operations are properly zeroed out. + 3. The number of node features matches the current number of nodes in the graph. + 4. The episode concludes with a complete schedule and no available actions. + """ env = multi_job_shop_graph_env + obs, _ = env.reset() + done = False - for _ in range(1): - done = False - obs, _ = env.reset() - while not done: - action = _random_action(obs) - obs, _, done, *_ = env.step(action) - - edge_index = obs[ObservationSpaceKey.EDGE_INDEX.value] - num_edges = env.observation_space[ # type: ignore[index] - ObservationSpaceKey.EDGE_INDEX.value - ].shape[1] - assert edge_index.shape == (2, num_edges) + while not done: + action = _random_action(obs) + obs, _, done, *_ = env.step(action) - padding_mask = edge_index == -1 - if np.any(padding_mask): - # Ensure all padding is at the end - for row in padding_mask: - padding_start = np.argmax(row) - if padding_start > 0: - assert np.all(row[padding_start:]) + # 1. 
Verify edge indices + edge_index_dict = obs[ObservationSpaceKey.EDGE_INDEX.value] + for edge_type, edges in edge_index_dict.items(): + assert edge_type in env.job_shop_graph.edge_types + assert ( + edges.ndim == 2 and edges.shape[0] == 2 + ), f"Edge index shape mismatch for {edge_type}: {edges.shape}" + # Ensure all node indices in edges are valid and point to active nodes + src_nodes_type, _, dst_nodes_type = edge_type + src_nodes, dst_nodes = edges + + # Calculate the maximum valid ID for source and destination node types + src_type_removed_nodes = env.job_shop_graph.removed_nodes[ + src_nodes_type + ] + dst_type_removed_nodes = env.job_shop_graph.removed_nodes[ + dst_nodes_type + ] + max_src_id = ( + len(src_type_removed_nodes) - sum(src_type_removed_nodes) - 1 + ) + max_dst_id = ( + len(dst_type_removed_nodes) - sum(dst_type_removed_nodes) - 1 + ) + + if edges.size > 0: # Only check if there are edges of this type + assert np.all( + src_nodes <= max_src_id + ), f"Source nodes {src_nodes} exceed max id {max_src_id} for type {src_nodes_type.name}" + assert np.all( + dst_nodes <= max_dst_id + ), f"Destination nodes {dst_nodes} exceed max id {max_dst_id} for type {dst_nodes_type.name}" + + # 2. Verify that features of completed operations are removed (sum to zero) + op_features = obs[ObservationSpaceKey.NODE_FEATURES.value][ + FeatureType.OPERATIONS.value + ] + op_completed_feats_sum = np.sum(op_features) + assert op_completed_feats_sum == 0, ( + "Operation features are not correctly zeroed out for completed nodes. " + f"Sum should be 0 but got {op_completed_feats_sum}" + ) + + # 3. Verify that the number of node features matches the number of active nodes + assert op_features.shape[0] == len( + env.job_shop_graph.nodes_by_type[NodeType.OPERATION] + ), ( + f"Operation features shape mismatch: {op_features.shape[0]} != " + f"{len(env.job_shop_graph.nodes_by_type[NodeType.OPERATION])}" + ) + + # 4. 
Verify terminal state + assert env.dispatcher.schedule.is_complete() + assert len(obs[ObservationSpaceKey.ACTION_MASK.value]) == 0, ( + "Action mask should be empty at the end of the episode but is not. " + f"Mask: {obs[ObservationSpaceKey.ACTION_MASK.value]}" + ) def test_all_nodes_are_removed( @@ -116,9 +153,9 @@ def test_all_nodes_are_removed( action = _random_action(obs) obs, _, done, *_ = env.step(action) - removed_nodes = obs[ObservationSpaceKey.REMOVED_NODES.value] + removed_nodes = multi_job_shop_graph_env.job_shop_graph.removed_nodes try: - assert np.all(removed_nodes) + assert all(np.all(lst) for lst in removed_nodes.values()) except AssertionError: print(removed_nodes) print(env.instance.to_dict()) diff --git a/tests/reinforcement_learning/test_resource_task_graph_observation.py b/tests/reinforcement_learning/test_resource_task_graph_observation.py deleted file mode 100644 index 40c5f66d..00000000 --- a/tests/reinforcement_learning/test_resource_task_graph_observation.py +++ /dev/null @@ -1,224 +0,0 @@ -import numpy as np -from job_shop_lib.reinforcement_learning import ( - SingleJobShopGraphEnv, - ResourceTaskGraphObservation, -) -from job_shop_lib.exceptions import ValidationError - - -def test_edge_index_dict( - single_env_ft06_resource_task_graph_with_all_features: ( - SingleJobShopGraphEnv - ), -): - env = ResourceTaskGraphObservation( - single_env_ft06_resource_task_graph_with_all_features - ) - obs, info = env.reset() - max_index = env.unwrapped.job_shop_graph.instance.num_operations - edge_index_dict = obs["edge_index_dict"] - _check_that_edge_index_has_been_reindexed(edge_index_dict, max_index) - - done = False - _, machine_id, job_id = info["available_operations_with_ids"][0] - removed_nodes = env.unwrapped.job_shop_graph.removed_nodes - _check_count_of_unique_ids(edge_index_dict, removed_nodes) - while not done: - obs, _, done, _, info = env.step((job_id, machine_id)) - if done: - break - edge_index_dict = obs["edge_index_dict"] - max_index = 
len(obs["node_features_dict"]["operation"]) - _check_that_edge_index_has_been_reindexed(edge_index_dict, max_index) - _, machine_id, job_id = info["available_operations_with_ids"][0] - _check_count_of_unique_ids(edge_index_dict, removed_nodes) - machine_id = obs["original_ids_dict"]["machine"][machine_id] - - -def test_node_features_dict( - single_env_ft06_resource_task_graph_with_all_features: ( - SingleJobShopGraphEnv - ), -): - env = ResourceTaskGraphObservation( - single_env_ft06_resource_task_graph_with_all_features - ) - obs, info = env.reset() - done = False - _, machine_id, job_id = info["available_operations_with_ids"][0] - removed_nodes = env.unwrapped.job_shop_graph.removed_nodes - _check_number_of_nodes(obs["node_features_dict"], removed_nodes) - while not done: - obs, _, done, _, info = env.step((job_id, machine_id)) - if done: - break - _check_number_of_nodes(obs["node_features_dict"], removed_nodes) - _, machine_id, job_id = info["available_operations_with_ids"][0] - machine_id = obs["original_ids_dict"]["machine"][machine_id] - is_completed_idx_ops = info["feature_names"]["operations"].index( - "IsCompleted" - ) - is_completed_idx_machines = info["feature_names"]["machines"].index( - "IsCompleted" - ) - assert np.all( - obs["node_features_dict"]["operation"][:, is_completed_idx_ops] - == 0 - ) - assert np.all( - obs["node_features_dict"]["machine"][:, is_completed_idx_machines] - == 0 - ) - assert obs["node_features_dict"]["operation"].shape[1] == 10 - - -def test_original_ids_dict( - single_env_ft06_resource_task_graph_with_all_features: ( - SingleJobShopGraphEnv - ), -): - env = ResourceTaskGraphObservation( - single_env_ft06_resource_task_graph_with_all_features - ) - obs, info = env.reset() - done = False - _, machine_id, job_id = info["available_operations_with_ids"][0] - removed_nodes = env.unwrapped.job_shop_graph.removed_nodes - while not done: - _check_original_ids_dict(obs["original_ids_dict"], removed_nodes) - _, machine_id, job_id = 
info["available_operations_with_ids"][0] - original_machine_id = obs["original_ids_dict"]["machine"][machine_id] - obs, _, done, _, info = env.step((job_id, original_machine_id)) - - -def test_type_ranges( - single_env_ft06_resource_task_graph_with_all_features: ( - SingleJobShopGraphEnv - ), -): - env = ResourceTaskGraphObservation( - single_env_ft06_resource_task_graph_with_all_features - ) - assert "operation" in env.type_ranges - assert "machine" in env.type_ranges - # (it does not have job nodes) - - assert env.type_ranges["operation"] == (0, 36) - assert env.type_ranges["machine"] == (36, 42) - assert len(env.type_ranges) == 2 - - -def test_info( - single_env_ft06_resource_task_graph_with_all_features: ( - SingleJobShopGraphEnv - ), -): - env = ResourceTaskGraphObservation( - single_env_ft06_resource_task_graph_with_all_features - ) - obs, info = env.reset() - done = False - _check_info_ids( - obs["node_features_dict"], - info["available_operations_with_ids"], - obs["original_ids_dict"], - env, - ) - while not done: - action = info["available_operations_with_ids"][0] - _, machine_id, job_id = action - original_machine_id = obs["original_ids_dict"]["machine"][machine_id] - obs, _, done, _, info = env.step((job_id, original_machine_id)) - _check_info_ids( - obs["node_features_dict"], - info["available_operations_with_ids"], - obs["original_ids_dict"], - env, - ) - - -def _check_info_ids( - node_features_dict: dict[str, np.ndarray], - available_actions_ids: list[tuple[int, int, int]], - original_ids_dict: dict[str, np.ndarray], - env: ResourceTaskGraphObservation[SingleJobShopGraphEnv], -): - for i, (node_type, node_features) in enumerate(node_features_dict.items()): - max_id = node_features.shape[0] - for action in available_actions_ids: - assert action[i] < max_id - if node_type == "machine": - original_machine_id = original_ids_dict[node_type][action[i]] - *_, job_id = action - try: - env.unwrapped.validate_action( - (job_id, original_machine_id) - ) - except 
ValidationError as e: - print(f"machine_id: {action[i]}") - print(f"original_machine_id: {original_machine_id}") - print(f"job_id: {job_id}") - print(f"original_ids_dict: {original_ids_dict}") - raise e - - -def _check_that_edge_index_has_been_reindexed( - edge_index_dict: dict, max_idx: int -): - values_found = np.zeros(max_idx) - for (type1, _, type2), edge_index in edge_index_dict.items(): - assert np.all(edge_index >= 0) - assert np.all(edge_index < max_idx) - if type1 != "operation" and type2 != "operation": - continue - - for value in range(max_idx): - if value in edge_index: - values_found[value] = 1 - assert np.all(values_found == 1) - - -def _check_count_of_unique_ids( - edge_index_dict: dict[tuple[str, str, str], np.ndarray], - removed_nodes: list[bool], -): - number_of_alive_nodes = len(removed_nodes) - sum(removed_nodes) - operation_nodes = np.zeros(36) - machine_nodes = np.zeros(6) - for key, edge_index in edge_index_dict.items(): - node_type1, _, node_type2 = key - if node_type1 == "operation": - operation_nodes[edge_index[0]] = 1 - elif node_type1 == "machine": - machine_nodes[edge_index[0]] = 1 - if node_type2 == "operation": - operation_nodes[edge_index[1]] = 1 - elif node_type2 == "machine": - machine_nodes[edge_index[1]] = 1 - num_unique_ids = operation_nodes.sum() + machine_nodes.sum() - assert num_unique_ids == number_of_alive_nodes - - -def _check_number_of_nodes( - node_features_dict: dict[str, np.ndarray], removed_nodes: list[bool] -): - number_of_alive_nodes = len(removed_nodes) - sum(removed_nodes) - total_nodes = 0 - for node_features in node_features_dict.values(): - total_nodes += node_features.shape[0] - assert total_nodes == number_of_alive_nodes - - -def _check_original_ids_dict( - original_ids_dict: dict[str, np.ndarray], removed_nodes: list[bool] -): - for node_type, original_ids in original_ids_dict.items(): - adjuster = 36 if node_type == "machine" else 0 - for original_id in original_ids: - assert not removed_nodes[original_id + 
adjuster] - - -if __name__ == "__main__": - import pytest - - pytest.main(["-vv", __file__]) diff --git a/tests/reinforcement_learning/test_single_job_shop_graph_env.py b/tests/reinforcement_learning/test_single_job_shop_graph_env.py index a30e5123..2b58231a 100644 --- a/tests/reinforcement_learning/test_single_job_shop_graph_env.py +++ b/tests/reinforcement_learning/test_single_job_shop_graph_env.py @@ -1,60 +1,70 @@ import random +import pytest + +import gymnasium as gym + import numpy as np + from job_shop_lib.reinforcement_learning import ( SingleJobShopGraphEnv, ObservationSpaceKey, ObservationDict, ) +from job_shop_lib.dispatching.feature_observers import ( + FeatureType, +) + +from job_shop_lib.graphs import NodeType -def random_action(observation: ObservationDict) -> tuple[int, int]: - ready_operations = [] - for operation_id, is_ready in enumerate( - observation[ObservationSpaceKey.JOBS.value].ravel() - ): - if is_ready == 1.0: - ready_operations.append(operation_id) - operation_id = random.choice(ready_operations) - machine_id = -1 # We can use -1 if each operation can only be scheduled - # in one machine. 
- return (operation_id, machine_id) +def random_action(observation: ObservationDict) -> tuple[int, int]: + available_operations_with_ids = observation[ + ObservationSpaceKey.ACTION_MASK.value + ] + operation_id, machine_id, _ = random.choice(available_operations_with_ids) + return (int(operation_id), int(machine_id)) def test_observation_space( single_job_shop_graph_env_ft06: SingleJobShopGraphEnv, ): env = single_job_shop_graph_env_ft06 - observation_space = single_job_shop_graph_env_ft06.observation_space - edge_index_shape = observation_space[ # type: ignore[index] - ObservationSpaceKey.EDGE_INDEX - ].shape - assert edge_index_shape == (2, env.job_shop_graph.num_edges) + observation_space = env.observation_space + num_edges = env.initial_job_shop_graph.num_edges + edge_index_shape = [2, 0] + for _, space in list( + observation_space[ObservationSpaceKey.EDGE_INDEX].spaces.items() + ): + edge_index_shape[1] += space.shape[1] + assert tuple(edge_index_shape) == (2, num_edges) + done = False obs, _ = env.reset() assert observation_space.contains(obs) while not done: action = random_action(obs) obs, _, done, *_ = env.step(action) - assert observation_space.contains(obs) - - env.use_padding = False done = False obs, _ = env.reset() edge_index_has_changed = False + while not done: action = random_action(obs) obs, _, done, *_ = env.step(action) edge_index = obs[ObservationSpaceKey.EDGE_INDEX.value] - if edge_index.shape != edge_index_shape: + shape = [2, 0] + for edge in edge_index.values(): + shape[1] += edge.shape[1] + if tuple(shape) != edge_index_shape: edge_index_has_changed = True break assert edge_index_has_changed -def test_edge_index_padding( +def test_observation( single_job_shop_graph_env_ft06: SingleJobShopGraphEnv, ): env = single_job_shop_graph_env_ft06 @@ -66,27 +76,55 @@ def test_edge_index_padding( obs, _, done, *_ = env.step(action) edge_index = obs[ObservationSpaceKey.EDGE_INDEX.value] - num_edges = env.observation_space[ # type: ignore[index] - 
ObservationSpaceKey.EDGE_INDEX.value - ].shape[1] - assert edge_index.shape == (2, num_edges) - - padding_mask = edge_index == -1 - if np.any(padding_mask): - # Ensure all padding is at the end - for row in padding_mask: - padding_start = np.argmax(row) - if padding_start > 0: - assert np.all(row[padding_start:]) + + for edge_type, edges in edge_index.items(): + assert edge_type in env.job_shop_graph.edge_types + assert ( + edges.ndim == 2 and edges.shape[0] == 2 + ), f"Edge index shape mismatch for {edge_type}: {edges.shape}" + src_nodes_type, _, dst_nodes_type = edge_type + src_nodes, dst_nodes = edges + src_type_removed_nodes = env.job_shop_graph.removed_nodes[ + src_nodes_type + ] + dst_type_removed_nodes = env.job_shop_graph.removed_nodes[ + dst_nodes_type + ] + max_src_id = ( + len(src_type_removed_nodes) - sum(src_type_removed_nodes) - 1 + ) + max_dst_id = ( + len(dst_type_removed_nodes) - sum(dst_type_removed_nodes) - 1 + ) + assert np.all( + src_nodes <= max_src_id + ), f"Source nodes {src_nodes} exceed max id {max_src_id}" + assert np.all( + dst_nodes <= max_dst_id + ), f"Destination nodes {dst_nodes} exceed max id {max_dst_id}" + # Using is_completed features for operations, to see features of removed operation nodes are removed + # We sum the operation features, if sum is equal to 0 always, means it is correctly + # removing completed operation nodes from the graph, thus returning correct node_features_dict + op_completed_feats_sum = np.sum( + obs[ObservationSpaceKey.NODE_FEATURES.value][ + FeatureType.OPERATIONS.value + ] + ) + assert ( + op_completed_feats_sum == 0 + ), f"Operation features are not correctly removed, sum should be 0 but got {op_completed_feats_sum}" + assert obs[ObservationSpaceKey.NODE_FEATURES.value][ + FeatureType.OPERATIONS.value + ].shape[0] == len( + env.job_shop_graph.nodes_by_type[NodeType.OPERATION] + ), f"Operation features shape mismatch: {obs[ObservationSpaceKey.NODE_FEATURES.value][FeatureType.OPERATIONS.value].shape[0]} 
!= {len(env.job_shop_graph.nodes_by_type[NodeType.OPERATION])}" assert env.dispatcher.schedule.is_complete() try: - assert np.all(obs[ObservationSpaceKey.REMOVED_NODES.value]) + assert len(obs[ObservationSpaceKey.ACTION_MASK.value]) == 0 except AssertionError: - print(obs[ObservationSpaceKey.REMOVED_NODES.value]) - print(env.instance.to_dict()) - print(env.instance) - print(env.job_shop_graph.nodes) + print("Action mask is not empty but schedule is complete:") + print(obs[ObservationSpaceKey.ACTION_MASK.value]) raise @@ -96,16 +134,22 @@ def test_all_nodes_removed( env = single_job_shop_graph_env_ft06_resource_task obs, _ = env.reset() done = False + print("Initial observation:") + print(obs) + while not done: action = random_action(obs) - obs, _, done, *_ = env.step(action) + obs, _, done, _, info = env.step(action) # type: ignore[call-arg] assert env.dispatcher.schedule.is_complete() - removed_nodes = obs[ObservationSpaceKey.REMOVED_NODES.value] + removed_nodes = env.job_shop_graph.removed_nodes assert env.job_shop_graph is env.graph_updater.job_shop_graph try: - assert np.all(removed_nodes) + print(removed_nodes) + print(removed_nodes.values()) + print([all(value) for value in removed_nodes.values()]) + assert all(all(value) for value in removed_nodes.values()) except AssertionError: print(removed_nodes) print(env.instance.to_dict()) diff --git a/tests/reinforcement_learning/test_utils.py b/tests/reinforcement_learning/test_utils.py index c7516d0c..e865558c 100644 --- a/tests/reinforcement_learning/test_utils.py +++ b/tests/reinforcement_learning/test_utils.py @@ -6,8 +6,8 @@ from job_shop_lib import Operation, JobShopInstance, ScheduledOperation from job_shop_lib.dispatching import Dispatcher from job_shop_lib.reinforcement_learning import ( - add_padding, create_edge_type_dict, + add_padding, map_values, get_optimal_actions, get_deadline_violation_penalty, diff --git a/tests/test_job_shop_instance.py b/tests/test_job_shop_instance.py index 6af2a6c7..d67f1323 
100644 --- a/tests/test_job_shop_instance.py +++ b/tests/test_job_shop_instance.py @@ -323,10 +323,10 @@ def test_eq(): def test_repr(job_shop_instance: JobShopInstance): """Test the string representation of JobShopInstance.""" expected_repr = ( - "JobShopInstance(name=TestInstance, num_jobs=2, " - "num_machines=3)" + "JobShopInstance(name=TestInstance, num_jobs=2, " "num_machines=3)" ) assert repr(job_shop_instance) == expected_repr + if __name__ == "__main__": pytest.main(["-vv", __file__]) diff --git a/tests/visualization/baseline/test_default_plot_disjunctive_graph.png b/tests/visualization/baseline/test_default_plot_disjunctive_graph.png index 264d9daf..cf359f41 100644 Binary files a/tests/visualization/baseline/test_default_plot_disjunctive_graph.png and b/tests/visualization/baseline/test_default_plot_disjunctive_graph.png differ diff --git a/tests/visualization/baseline/test_plot_disjunctive_graph_removed_nodes.png b/tests/visualization/baseline/test_plot_disjunctive_graph_removed_nodes.png index 4d372c38..fb548d52 100644 Binary files a/tests/visualization/baseline/test_plot_disjunctive_graph_removed_nodes.png and b/tests/visualization/baseline/test_plot_disjunctive_graph_removed_nodes.png differ diff --git a/tests/visualization/baseline/test_plot_disjunctive_graph_removed_nodes_default_machine_colors.png b/tests/visualization/baseline/test_plot_disjunctive_graph_removed_nodes_default_machine_colors.png index ad1fdd53..59c6406a 100644 Binary files a/tests/visualization/baseline/test_plot_disjunctive_graph_removed_nodes_default_machine_colors.png and b/tests/visualization/baseline/test_plot_disjunctive_graph_removed_nodes_default_machine_colors.png differ diff --git a/tests/visualization/baseline/test_plot_disjunctive_graph_single_edge_machine_colors.png b/tests/visualization/baseline/test_plot_disjunctive_graph_single_edge_machine_colors.png index 5cdc864b..a22ad66f 100644 Binary files 
a/tests/visualization/baseline/test_plot_disjunctive_graph_single_edge_machine_colors.png and b/tests/visualization/baseline/test_plot_disjunctive_graph_single_edge_machine_colors.png differ diff --git a/tests/visualization/gantt/test_plot_gantt_chart.py b/tests/visualization/gantt/test_plot_gantt_chart.py index 592b2987..415f7cd3 100644 --- a/tests/visualization/gantt/test_plot_gantt_chart.py +++ b/tests/visualization/gantt/test_plot_gantt_chart.py @@ -5,10 +5,13 @@ from job_shop_lib.visualization.gantt import plot_gantt_chart -@pytest.mark.mpl_image_compare( - style="default", - savefig_kwargs={"dpi": 300, "bbox_inches": "tight"} -) +KWARGS_MPL_IMAGE_COMPARE = { + "style": "default", + "tolerance": 10, + "savefig_kwargs": {"dpi": 300, "bbox_inches": "tight"}, +} + +@pytest.mark.mpl_image_compare(**KWARGS_MPL_IMAGE_COMPARE) def test_plot_gantt_chart_default(example_schedule: Schedule): fig, ax = plot_gantt_chart(example_schedule) assert isinstance(fig, Figure) @@ -16,10 +19,7 @@ def test_plot_gantt_chart_default(example_schedule: Schedule): return fig -@pytest.mark.mpl_image_compare( - style="default", - savefig_kwargs={"dpi": 300, "bbox_inches": "tight"} -) +@pytest.mark.mpl_image_compare(**KWARGS_MPL_IMAGE_COMPARE) def test_plot_gantt_chart_custom_title(example_schedule: Schedule): fig, ax = plot_gantt_chart(example_schedule, title="Custom Title") assert isinstance(fig, Figure) @@ -28,10 +28,7 @@ def test_plot_gantt_chart_custom_title(example_schedule: Schedule): return fig -@pytest.mark.mpl_image_compare( - style="default", - savefig_kwargs={"dpi": 300, "bbox_inches": "tight"} -) +@pytest.mark.mpl_image_compare(**KWARGS_MPL_IMAGE_COMPARE) def test_plot_gantt_chart_no_title(example_schedule: Schedule): fig, ax = plot_gantt_chart(example_schedule, title="") assert isinstance(fig, Figure) @@ -40,10 +37,7 @@ def test_plot_gantt_chart_no_title(example_schedule: Schedule): return fig -@pytest.mark.mpl_image_compare( - style="default", - savefig_kwargs={"dpi": 300, 
"bbox_inches": "tight"} -) +@pytest.mark.mpl_image_compare(**KWARGS_MPL_IMAGE_COMPARE) def test_plot_gantt_chart_custom_labels(example_schedule: Schedule): job_labels = ["Job A", "Job B", "Job C"] machine_labels = ["Machine X", "Machine Y", "Machine Z"] @@ -59,16 +53,17 @@ def test_plot_gantt_chart_custom_labels(example_schedule: Schedule): assert ax is not None assert ax.get_xlabel() == "Custom X Label" assert ax.get_ylabel() == "Custom Y Label" - assert ax.get_legend().get_title().get_text() == "Custom Legend" + legend = ax.get_legend() + assert legend is not None + assert legend.get_title().get_text() == "Custom Legend" assert [tick.get_text() for tick in ax.get_yticklabels()] == machine_labels - assert [text.get_text() for text in ax.get_legend().get_texts()] == job_labels + assert [ + text.get_text() for text in legend.get_texts() + ] == job_labels return fig -@pytest.mark.mpl_image_compare( - style="default", - savefig_kwargs={"dpi": 300, "bbox_inches": "tight"} -) +@pytest.mark.mpl_image_compare(**KWARGS_MPL_IMAGE_COMPARE) def test_plot_gantt_chart_custom_xlim_and_ticks(example_schedule: Schedule): custom_xlim = 20 num_ticks = 10 @@ -83,10 +78,7 @@ def test_plot_gantt_chart_custom_xlim_and_ticks(example_schedule: Schedule): return fig -@pytest.mark.mpl_image_compare( - style="default", - savefig_kwargs={"dpi": 300, "bbox_inches": "tight"} -) +@pytest.mark.mpl_image_compare(**KWARGS_MPL_IMAGE_COMPARE) def test_plot_gantt_chart_different_cmap(example_schedule: Schedule): fig, ax = plot_gantt_chart(example_schedule, cmap_name="plasma") assert isinstance(fig, Figure) diff --git a/tests/visualization/graphs/test_plot_disjunctive_graph.py b/tests/visualization/graphs/test_plot_disjunctive_graph.py index 5ca9813e..40edd751 100644 --- a/tests/visualization/graphs/test_plot_disjunctive_graph.py +++ b/tests/visualization/graphs/test_plot_disjunctive_graph.py @@ -4,14 +4,22 @@ import matplotlib from job_shop_lib import JobShopInstance -from job_shop_lib.graphs import 
build_disjunctive_graph, NodeType, Node +from job_shop_lib.graphs import ( + build_disjunctive_graph, + NodeType, + Node, +) from job_shop_lib.visualization.graphs import plot_disjunctive_graph from job_shop_lib.exceptions import ValidationError -@pytest.mark.mpl_image_compare( - style="default", savefig_kwargs={"dpi": 300, "bbox_inches": "tight"} -) +KWARGS_MPL_IMAGE_COMPARE = { + "style": "default", + "tolerance": 20, + "savefig_kwargs": {"dpi": 300, "bbox_inches": "tight"}, +} + +@pytest.mark.mpl_image_compare(**KWARGS_MPL_IMAGE_COMPARE) def test_default_plot_disjunctive_graph( example_job_shop_instance: JobShopInstance, ): @@ -22,7 +30,7 @@ def test_default_plot_disjunctive_graph( @pytest.mark.mpl_image_compare( - style="default", savefig_kwargs={"dpi": 300, "bbox_inches": "tight"} + **KWARGS_MPL_IMAGE_COMPARE ) def test_plot_disjunctive_graph_single_edge_machine_colors( example_job_shop_instance: JobShopInstance, @@ -59,7 +67,7 @@ def test_plot_disjunctive_graph_single_edge_machine_colors( @pytest.mark.mpl_image_compare( - style="default", savefig_kwargs={"dpi": 300, "bbox_inches": "tight"} + **KWARGS_MPL_IMAGE_COMPARE ) def test_plot_disjunctive_graph_removed_nodes( example_job_shop_instance: JobShopInstance, @@ -69,22 +77,12 @@ def test_plot_disjunctive_graph_removed_nodes( also using machine_colors and single_edge for disjunctive edges. 
""" graph = build_disjunctive_graph(example_job_shop_instance) - op_nodes = [ node for node in graph.nodes if node.node_type == NodeType.OPERATION ] - - # Remove 1st, 3rd, 4th operation nodes if they exist - nodes_to_remove_indices = [0, 2, 3] - removed_count = 0 - for index_to_remove in nodes_to_remove_indices: - actual_index = index_to_remove - removed_count - if actual_index < len(op_nodes): - node_to_remove = op_nodes.pop(actual_index) - if not graph.is_removed(node_to_remove.node_id): - graph.remove_node(node_to_remove.node_id) - removed_count += 1 - + to_remove = [op_nodes[i] for i in [0, 2, 3] if i < len(op_nodes)] + for node in to_remove: + graph.remove_node(node.node_id) num_machines = example_job_shop_instance.num_machines cmap = matplotlib.colormaps.get_cmap("plasma") machine_colors = { @@ -111,7 +109,7 @@ def test_plot_disjunctive_graph_removed_nodes( @pytest.mark.mpl_image_compare( - style="default", savefig_kwargs={"dpi": 300, "bbox_inches": "tight"} + **KWARGS_MPL_IMAGE_COMPARE ) def test_plot_disjunctive_graph_removed_nodes_default_machine_colors( example_job_shop_instance: JobShopInstance, @@ -129,7 +127,7 @@ def test_plot_disjunctive_graph_removed_nodes_default_machine_colors( actual_index = index_to_remove - removed_count if actual_index < len(op_nodes): node_to_remove = op_nodes.pop(actual_index) - if not graph.is_removed(node_to_remove.node_id): + if not graph.is_removed(node_to_remove): graph.remove_node(node_to_remove.node_id) removed_count += 1 diff --git a/tests/visualization/graphs/test_plot_resource_task_graph.py b/tests/visualization/graphs/test_plot_resource_task_graph.py index 77e67016..cd08ff3c 100644 --- a/tests/visualization/graphs/test_plot_resource_task_graph.py +++ b/tests/visualization/graphs/test_plot_resource_task_graph.py @@ -11,9 +11,14 @@ ) -@pytest.mark.mpl_image_compare( - style="default", savefig_kwargs={"dpi": 300, "bbox_inches": "tight"} -) +KWARGS_MPL_IMAGE_COMPARE = { + "style": "default", + "tolerance": 10, + 
"savefig_kwargs": {"dpi": 300, "bbox_inches": "tight"}, +} + + +@pytest.mark.mpl_image_compare(**KWARGS_MPL_IMAGE_COMPARE) def test_plot_resource_task_graph(example_job_shop_instance): graph = build_resource_task_graph(example_job_shop_instance) fig = plot_resource_task_graph(graph) @@ -21,9 +26,7 @@ def test_plot_resource_task_graph(example_job_shop_instance): return fig -@pytest.mark.mpl_image_compare( - style="default", savefig_kwargs={"dpi": 300, "bbox_inches": "tight"} -) +@pytest.mark.mpl_image_compare(**KWARGS_MPL_IMAGE_COMPARE) def test_plot_resource_task_graph_with_jobs(example_job_shop_instance): graph = build_resource_task_graph_with_jobs(example_job_shop_instance) fig = plot_resource_task_graph(graph) @@ -32,7 +35,7 @@ def test_plot_resource_task_graph_with_jobs(example_job_shop_instance): @pytest.mark.mpl_image_compare( - style="default", savefig_kwargs={"dpi": 300, "bbox_inches": "tight"} + **KWARGS_MPL_IMAGE_COMPARE ) def test_plot_complete_resource_task_graph(example_job_shop_instance): graph = build_complete_resource_task_graph(example_job_shop_instance) @@ -41,7 +44,7 @@ def test_plot_complete_resource_task_graph(example_job_shop_instance): @pytest.mark.mpl_image_compare( - style="default", savefig_kwargs={"dpi": 300, "bbox_inches": "tight"} + **KWARGS_MPL_IMAGE_COMPARE ) def test_plot_resource_task_graph_custom_title_legend( example_job_shop_instance, @@ -58,7 +61,7 @@ def test_plot_resource_task_graph_custom_title_legend( @pytest.mark.mpl_image_compare( - style="default", savefig_kwargs={"dpi": 300, "bbox_inches": "tight"} + **KWARGS_MPL_IMAGE_COMPARE ) def test_plot_resource_task_graph_with_jobs_doble_arrow( example_job_shop_instance, @@ -84,9 +87,7 @@ def test_plot_resource_task_graph_with_jobs_doble_arrow( return fig -@pytest.mark.mpl_image_compare( - style="default", savefig_kwargs={"dpi": 300, "bbox_inches": "tight"} -) +@pytest.mark.mpl_image_compare(**KWARGS_MPL_IMAGE_COMPARE) def 
test_plot_resource_task_graph_with_jobs_single_edge_custom_params( example_job_shop_instance, ): @@ -105,9 +106,7 @@ def test_plot_resource_task_graph_with_jobs_single_edge_custom_params( return fig -@pytest.mark.mpl_image_compare( - style="default", savefig_kwargs={"dpi": 300, "bbox_inches": "tight"} -) +@pytest.mark.mpl_image_compare(**KWARGS_MPL_IMAGE_COMPARE) def test_plot_complete_resource_task_graph_custom_shapes_colors_layout( example_job_shop_instance, ): @@ -138,9 +137,7 @@ def test_plot_complete_resource_task_graph_custom_shapes_colors_layout( def test_plot_complete_resource_task_graph_custom_edge_params( example_job_shop_instance, ): - graph = build_complete_resource_task_graph( - example_job_shop_instance - ) + graph = build_complete_resource_task_graph(example_job_shop_instance) fig = plot_resource_task_graph( graph, title="",