@@ -509,16 +509,36 @@ def get_g_arrow_type(endpoint):
509509 return graphviz_g
510510
511511 @staticmethod
512- def to_pydot (G : Graph , edges : List [Edge ] | None = None , title : str = "" , dpi : int = 200 ):
512+ def to_pydot (G : Graph , edges : List [Edge ] | None = None , labels : List [str ] | None = None , title : str = "" , dpi : float = 200 ):
513+ '''
514+ Convert a graph object to a DOT object.
515+
516+ Parameters
517+ ----------
518+ G : Graph
519+ A graph object of causal-learn
520+ edges : list, optional (default=None)
521+ Edges list of graph G
522+ labels : list, optional (default=None)
523+ Nodes labels of graph G
524+ title : str, optional (default="")
525+ The name of graph G
526+ dpi : float, optional (default=200)
527+ The dots per inch of dot object
528+ Returns
529+ -------
530+ pydot_g : Dot
531+ '''
513532 pydot_g = pydot .Dot (title , graph_type = "digraph" , fontsize = 18 )
514533 pydot_g .obj_dict ["attributes" ]["dpi" ] = dpi
515534 nodes = G .get_nodes ()
516535 for i , node in enumerate (nodes ):
536+ node_name = labels [i ] if labels is not None and len (labels ) > i else node .get_name ()
517537 pydot_g .add_node (pydot .Node (i , label = node .get_name ()))
518538 if node .get_node_type () == NodeType .LATENT :
519- pydot_g .add_node (pydot .Node (i , label = node . get_name () , shape = 'square' ))
539+ pydot_g .add_node (pydot .Node (i , label = node_name , shape = 'square' ))
520540 else :
521- pydot_g .add_node (pydot .Node (i , label = node . get_name () ))
541+ pydot_g .add_node (pydot .Node (i , label = node_name ))
522542
523543 def get_g_arrow_type (endpoint ):
524544 if endpoint == Endpoint .TAIL :
0 commit comments