@@ -24,7 +24,7 @@ class ER_Diagram(FileWriter):
2424 """
2525 DSI Writer that generates an ER Diagram from the current data in the DSI abstraction
2626 """
27- def __init__ (self , filename , target_table_prefix = None , ** kwargs ):
27+ def __init__ (self , filename , target_table_prefix = None , max_cols = None , ** kwargs ):
2828 """
2929 Initializes the ER Diagram writer
3030
@@ -35,10 +35,15 @@ def __init__(self, filename, target_table_prefix = None, **kwargs):
3535 If provided, filters the ER Diagram to only include tables whose names begin with this prefix.
3636
3737 - Ex: If prefix = "student", only "student__address", "student__math", "student__physics" tables are included
38+
39+ `max_cols` : int, optional, default None
40+ If provided, limits the number of columns displayed for each table in the ER Diagram.
41+ If relational data is included, this must be >= number of primary and foreign keys for a table.
3842 """
3943 super ().__init__ (filename , ** kwargs )
4044 self .output_filename = filename
4145 self .target_table_prefix = target_table_prefix
46+ self .max_cols = max_cols
4247
4348 def get_rows (self , collection ) -> None :
4449 """
@@ -99,7 +104,23 @@ def get_rows(self, collection) -> None:
99104
100105 col_list = tableData .keys ()
101106 if tableName == "dsi_units" :
102- col_list = ["table_name" , "column_and_unit" ]
107+ col_list = ["table_name" , "column_name" , "unit" ]
108+ if self .max_cols is not None :
109+ if "dsi_relations" in collection .keys ():
110+ fk_cols = [t [1 ] for t in collection ["dsi_relations" ]["foreign_key" ] if t [0 ] == tableName ]
111+ pk_cols = [t [1 ] for t in collection ["dsi_relations" ]["primary_key" ] if t [0 ] == tableName ]
112+ rel_cols = set (pk_cols + fk_cols )
113+
114+ if rel_cols :
115+ if len (rel_cols ) > self .max_cols :
116+ return (ValueError , "'max_cols' must be >= to the number of primary/foreign key columns." )
117+ other_cols = [col for col in col_list if col not in rel_cols ]
118+ combined = list (rel_cols ) + other_cols [:self .max_cols - len (rel_cols )]
119+ col_list = [k for k in col_list if k in combined ]
120+ col_list = col_list [:self .max_cols ]
121+ if len (tableData .keys ()) > self .max_cols :
122+ col_list .append ("..." )
123+
103124 curr_row = 0
104125 inner_brace = 0
105126 for col_name in col_list :
@@ -121,9 +142,9 @@ def get_rows(self, collection) -> None:
121142
122143 if "dsi_relations" in collection .keys ():
123144 for f_table , f_col in collection ["dsi_relations" ]["foreign_key" ]:
124- if self .target_table_prefix is not None and self .target_table_prefix not in f_table :
145+ if self .target_table_prefix is not None and f_table is not None and self .target_table_prefix not in f_table :
125146 continue
126- if f_table != None :
147+ if f_table is not None :
127148 foreignIndex = collection ["dsi_relations" ]["foreign_key" ].index ((f_table , f_col ))
128149 fk_string = f"{ f_table } :{ f_col } "
129150 pk_string = f"{ collection ['dsi_relations' ]['primary_key' ][foreignIndex ][0 ]} :{ collection ['dsi_relations' ]['primary_key' ][foreignIndex ][1 ]} "
@@ -137,7 +158,10 @@ def get_rows(self, collection) -> None:
137158 subprocess .run (["dot" , "-T" , file_type [1 :], "-o" , self .output_filename + file_type , self .output_filename + ".dot" ])
138159 os .remove (self .output_filename + ".dot" )
139160 else :
140- dot .render (self .output_filename , cleanup = True )
161+ try :
162+ dot .render (self .output_filename , cleanup = True )
163+ except :
164+ return (EnvironmentError , "Graphviz executable must be downloaded to global environment using sudo or homebrew." )
141165
142166class Csv_Writer (FileWriter ):
143167 """
0 commit comments