33
44import maxplotlib .subfigure .tikz_figure as tf
55
6+ class Node :
7+ def __init__ (self , x , y , label = "" , content = "" , layer = 0 , ** kwargs ):
8+ self .x = x
9+ self .y = y
10+ self .label = label
11+ self .content = content
12+ self .layer = layer
13+ self .options = kwargs
14+
15+ class Path :
16+ def __init__ (
17+ self , nodes , path_actions = [], cycle = False , label = "" , layer = 0 , ** kwargs
18+ ):
19+ self .nodes = nodes
20+ self .path_actions = path_actions
21+ self .cycle = cycle
22+ self .layer = layer
23+ self .label = label
24+ self .options = kwargs
25+
626class LinePlot :
727 def __init__ (self , ** kwargs ):
828 """
@@ -19,6 +39,7 @@ def __init__(self, **kwargs):
1939 """
2040 # Set default values
2141 self ._figsize = kwargs .get ("figsize" , (10 , 6 ))
42+ self ._title = kwargs .get ("title" , None )
2243 self ._caption = kwargs .get ("caption" , None )
2344 self ._description = kwargs .get ("description" , None )
2445 self ._label = kwargs .get ("label" , None )
@@ -31,6 +52,14 @@ def __init__(self, **kwargs):
3152 self .line_data = []
3253 self .layered_line_data = {}
3354
55+ # Initialize lists to hold Node and Path objects
56+ self .nodes = []
57+ self .paths = []
58+ #self.layers = {}
59+
60+ # Counter for unnamed nodes
61+ self ._node_counter = 0
62+
3463 # Scaling
3564 self ._xscale = kwargs .get ("xscale" , 1.0 )
3665 self ._yscale = kwargs .get ("yscale" , 1.0 )
@@ -40,7 +69,7 @@ def __init__(self, **kwargs):
4069 def add_caption (self , caption ):
4170 self ._caption = caption
4271
43- def add_line (self , x_data , y_data , layer = 0 , ** kwargs ):
72+ def add_line (self , x_data , y_data , layer = 0 , plot_type = 'plot' , ** kwargs ):
4473 """
4574 Add a line to the plot.
4675
@@ -54,6 +83,7 @@ def add_line(self, x_data, y_data, layer=0, **kwargs):
5483 "x" : np .array (x_data ),
5584 "y" : np .array (y_data ),
5685 "layer" : layer ,
86+ "plot_type" : plot_type ,
5787 "kwargs" : kwargs ,
5888 }
5989 self .line_data .append (ld )
@@ -80,13 +110,22 @@ def plot_matplotlib(self, ax, layers=None):
80110 if layers and layer_name not in layers :
81111 continue
82112 for line in layer_lines :
83- ax .plot (
84- (line ["x" ] + self ._xshift ) * self ._xscale ,
85- (line ["y" ] + self ._yshift ) * self ._yscale ,
86- ** line ["kwargs" ],
87- )
88- if self ._caption :
89- ax .set_title (self ._caption )
113+ if line ["plot_type" ] == "plot" :
114+ ax .plot (
115+ (line ["x" ] + self ._xshift ) * self ._xscale ,
116+ (line ["y" ] + self ._yshift ) * self ._yscale ,
117+ ** line ["kwargs" ],
118+ )
119+ elif line ["plot_type" ] == "scatter" :
120+ ax .scatter (
121+ (line ["x" ] + self ._xshift ) * self ._xscale ,
122+ (line ["y" ] + self ._yshift ) * self ._yscale ,
123+ ** line ["kwargs" ],
124+ )
125+ # if self._caption:
126+ # ax.set_title(self._caption)
127+ if self ._title :
128+ ax .set_title (self ._title )
90129 if self ._label :
91130 ax .set_ylabel (self ._label )
92131 if self ._xlabel :
@@ -127,6 +166,67 @@ def plot_plotly(self):
127166 traces .append (trace )
128167
129168 return traces
169+
170+ def add_node (self , x , y , label = None , content = "" , layer = 0 , ** kwargs ):
171+ """
172+ Add a node to the TikZ figure.
173+
174+ Parameters:
175+ - x (float): X-coordinate of the node.
176+ - y (float): Y-coordinate of the node.
177+ - label (str, optional): Label of the node. If None, a default label will be assigned.
178+ - **kwargs: Additional TikZ node options (e.g., shape, color).
179+
180+ Returns:
181+ - node (Node): The Node object that was added.
182+ """
183+ if label is None :
184+ label = f"node{ self ._node_counter } "
185+ node = Node (x = x , y = y , label = label , layer = layer , content = content , ** kwargs )
186+ self .nodes .append (node )
187+ if layer in self .layers :
188+ self .layers [layer ].add (node )
189+ else :
190+ self .layers [layer ] = Tikzlayer (layer )
191+ self .layers [layer ].add (node )
192+ self ._node_counter += 1
193+ return node
194+
195+ def add_path (self , nodes , layer = 0 , ** kwargs ):
196+ """
197+ Add a line or path connecting multiple nodes.
198+
199+ Parameters:
200+ - nodes (list of str): List of node names to connect.
201+ - **kwargs: Additional TikZ path options (e.g., style, color).
202+
203+ Examples:
204+ - add_path(['A', 'B', 'C'], color='blue')
205+ Connects nodes A -> B -> C with a blue line.
206+ """
207+ if not isinstance (nodes , list ):
208+ raise ValueError ("nodes parameter must be a list of node names." )
209+
210+ nodes = [
211+ (
212+ node
213+ if isinstance (node , Node )
214+ else (
215+ self .get_node (node )
216+ if isinstance (node , str )
217+ else ValueError (f"Invalid node type: { type (node )} " )
218+ )
219+ )
220+ for node in nodes
221+ ]
222+ path = Path (nodes , ** kwargs )
223+ self .paths .append (path )
224+ if layer in self .layers :
225+ self .layers [layer ].add (path )
226+ else :
227+ self .layers [layer ] = Tikzlayer (layer )
228+ self .layers [layer ].add (path )
229+ return path
130230
131231 # Getter and Setter for figsize
132232 @property
0 commit comments