Skip to content

Commit 719be19

Browse files
authored
Merge pull request #185 from lanl/issue184
updated backend type assignment and num cols to display for er diagram
2 parents 62179a0 + a5dde18 commit 719be19

File tree

13 files changed

+131
-111
lines changed

13 files changed

+131
-111
lines changed

README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ Copyright and License
5959

6060
This program is open source under the BSD-3 License.
6161

62-
© 2025. Triad National Security, LLC. All rights reserved.
62+
© 2025. Triad National Security, LLC. All rights reserved. LA-UR-25-29245
6363

6464
Redistribution and use in source and binary forms, with or without modification, are permitted
6565
provided that the following conditions are met:

docs/conf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
exec(open("../dsi/_version.py").read())
99

1010
project = 'DSI'
11-
copyright = '2025, Triad National Security, LLC. All rights reserved.'
11+
copyright = '2025, Triad National Security, LLC. All rights reserved. LA-UR-25-29248'
1212
author = 'The DSI Project team'
1313
release = __version__
1414

dsi/backends/duckdb.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,19 @@ def sql_type(self, input_list):
6868
`return`: str
6969
A string representing the inferred DuckDB data type for the input list.
7070
"""
71-
for item in input_list:
72-
if isinstance(item, int):
73-
return " INTEGER"
74-
elif isinstance(item, float):
75-
return " FLOAT"
76-
elif isinstance(item, str):
77-
return " VARCHAR"
71+
DUCKDB_BIGINT_MIN = -9223372036854775808
72+
DUCKDB_BIGINT_MAX = 9223372036854775807
73+
DUCKDB_INT_MIN = -2147483648
74+
DUCKDB_INT_MAX = 2147483647
75+
76+
if all(isinstance(x, int) for x in input_list if x is not None):
77+
if any(x < DUCKDB_BIGINT_MIN or x > DUCKDB_BIGINT_MAX for x in input_list if x is not None):
78+
return " DOUBLE"
79+
elif any(x < DUCKDB_INT_MIN or x > DUCKDB_INT_MAX for x in input_list if x is not None):
80+
return " BIGINT"
81+
return " INTEGER"
82+
elif all(isinstance(x, float) for x in input_list if x is not None):
83+
return " DOUBLE"
7884
return " VARCHAR"
7985

8086
def duckdb_compatible_name(self, name):
@@ -822,7 +828,7 @@ def summary_helper(self, table_name):
822828
"""
823829
col_info = self.cur.execute(f"PRAGMA table_info({table_name})").fetchall()
824830

825-
numeric_types = {'INTEGER', 'REAL', 'FLOAT', 'NUMERIC', 'DECIMAL', 'DOUBLE'}
831+
numeric_types = {'INTEGER', 'REAL', 'FLOAT', 'NUMERIC', 'DECIMAL', 'DOUBLE', 'BIGINT'}
826832
headers = ['column', 'type', 'min', 'max', 'avg', 'std_dev']
827833
rows = []
828834

dsi/backends/sqlite.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,15 @@ class Sqlite(Filesystem):
4545
"""
4646
runTable = False
4747

48-
def __init__(self, filename):
48+
def __init__(self, filename, **kwargs):
4949
"""
5050
Initializes a SQLite backend with a user inputted filename, and creates other internal variables
5151
"""
5252
self.filename = filename
53-
self.con = sqlite3.connect(filename)
53+
if 'kwargs' in kwargs:
54+
self.con = sqlite3.connect(filename, **kwargs['kwargs'])
55+
else:
56+
self.con = sqlite3.connect(filename)
5457
self.cur = self.con.cursor()
5558
self.runTable = Sqlite.runTable
5659
self.sqlite_keywords = ["ABORT", "ACTION", "ADD", "AFTER", "ALL", "ALTER", "ALWAYS", "ANALYZE", "AND", "AS", "ASC", "ATTACH",
@@ -80,14 +83,16 @@ def sql_type(self, input_list):
8083
`return`: str
8184
A string representing the inferred SQLite data type for the input list.
8285
"""
83-
for item in input_list:
84-
if isinstance(item, int):
85-
return " INTEGER"
86-
elif isinstance(item, float):
86+
SQLITE_INT_MIN = -9223372036854775808
87+
SQLITE_INT_MAX = 9223372036854775807
88+
89+
if all(isinstance(x, int) for x in input_list if x is not None):
90+
if any(x < SQLITE_INT_MIN or x > SQLITE_INT_MAX for x in input_list if x is not None):
8791
return " FLOAT"
88-
elif isinstance(item, str):
89-
return " VARCHAR"
90-
return ""
92+
return " INTEGER"
93+
elif all(isinstance(x, float) for x in input_list if x is not None):
94+
return " FLOAT"
95+
return " VARCHAR"
9196

9297
def sqlite_compatible_name(self, name):
9398
if (name.startswith('"') and name.endswith('"')) or (name.upper() not in self.sqlite_keywords and name.isidentifier()):

dsi/core.py

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,7 @@ def wrap_in_quotes(value):
872872
self.logger.error(f"Error finding rows due to {return_object[1]}")
873873
raise return_object[0](return_object[1])
874874
elif isinstance(return_object, list) and isinstance(return_object[0], str):
875-
err_msg = f"'{column_name}' appeared in more than one table. Can only do a conditional find if '{column_name}' is in one table"
875+
err_msg = f"'{column_name}' appeared in more than one table. Can only find if '{column_name}' is in one table"
876876
if self.debug_level != 0:
877877
self.logger.warning(err_msg)
878878
print(f"WARNING: {err_msg}")
@@ -1479,28 +1479,6 @@ def index(self, local_loc, remote_loc, isVerbose=False):
14791479
with redirect_stdout(fnull):
14801480
t.load_module('plugin', "Dict", "reader", collection=st_dict, table_name="filesystem")
14811481
t.artifact_handler(interaction_type='ingest')
1482-
1483-
# # Create new filesystem collection with origin and remote locations
1484-
# # Stage data for ingest
1485-
# # Transpose the OrderedDict to a list of row dictionaries
1486-
# num_rows = len(next(iter(st_dict.values()))) # Assume all columns are of equal length
1487-
# rows = []
1488-
1489-
# for i in range(num_rows):
1490-
# row = {col: values[i] for col, values in st_dict.items()}
1491-
# rows.append(row)
1492-
1493-
# # Temporary csv to ingest
1494-
# output_file = '.fs.csv'
1495-
# with open(output_file, mode='w', newline='') as csvfile:
1496-
# writer = csv.DictWriter(csvfile, fieldnames=st_dict.keys())
1497-
# writer.writeheader()
1498-
# writer.writerows(rows)
1499-
1500-
# # Add filesystem table
1501-
# t.load_module('plugin', 'Csv', 'reader', filenames=".fs.csv", table_name="filesystem")
1502-
# #t.load_module('plugin', 'collection_reader', 'reader', st_dict )
1503-
# t.artifact_handler(interaction_type='ingest')
15041482

15051483
t.close()
15061484

dsi/dsi.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class DSI():
1717
The DSI Class abstracts Core.Terminal for managing metadata and Core.Sync for data management and movement.
1818
'''
1919

20-
def __init__(self, filename = ".temp.db", backend_name = "Sqlite"):
20+
def __init__(self, filename = ".temp.db", backend_name = "Sqlite", **kwargs):
2121
"""
2222
Initializes DSI by activating a backend for data operations; default is a Sqlite backend for temporary data analysis.
2323
If users specify `filename`, data is saved to a permanent backend file.
@@ -61,7 +61,7 @@ def __init__(self, filename = ".temp.db", backend_name = "Sqlite"):
6161
try:
6262
if backend_name.lower() == 'sqlite':
6363
with redirect_stdout(fnull):
64-
self.t.load_module('backend','Sqlite','back-write', filename=filename)
64+
self.t.load_module('backend','Sqlite','back-write', filename=filename, kwargs = kwargs)
6565
self.backend_name = "sqlite"
6666
elif backend_name.lower() == 'duckdb':
6767
with redirect_stdout(fnull):

dsi/plugins/file_reader.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -746,22 +746,26 @@ def add_rows(self) -> None:
746746
field_names = []
747747
for element, val in data.items():
748748
if element not in ['authorship', 'data']:
749+
if isinstance(val, list):
750+
val = ",, ".join(val)
749751
if element not in temp_data.keys():
750752
temp_data[element] = [val]
751753
else:
752754
temp_data[element].append(val)
753755
field_names.append(element)
754756
else:
755757
for field, val2 in val.items():
758+
if isinstance(val2, list):
759+
val2 = ",, ".join(val2)
756760
if field not in temp_data.keys():
757761
temp_data[field] = [val2]
758762
else:
759763
temp_data[field].append(val2)
760764
field_names.append(field)
761765

762-
if sorted(field_names) != sorted(["name", "description", "data_uses", "creators", "creation_date",
763-
"la_ur", "owner", "funding", "publisher", "published_date", "origin_location",
764-
"num_simulations", "version", "license", "live_dataset"]):
766+
if sorted(field_names) != sorted(["title", "description", "keywords", "instructions_of_use", "authors",
767+
"release_date", "la_ur", "funding", "rights", "file_types", "num_simulations",
768+
"file_size", "num_files", "dataset_size", "version", "doi"]):
765769
return (ValueError, f"Error in reading {filename} data card. Please ensure all fields included match the template")
766770

767771
self.datacard_data["oceans11_datacard"] = temp_data

dsi/plugins/file_writer.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class ER_Diagram(FileWriter):
2424
"""
2525
DSI Writer that generates an ER Diagram from the current data in the DSI abstraction
2626
"""
27-
def __init__(self, filename, target_table_prefix = None, **kwargs):
27+
def __init__(self, filename, target_table_prefix = None, max_cols = None, **kwargs):
2828
"""
2929
Initializes the ER Diagram writer
3030
@@ -35,10 +35,15 @@ def __init__(self, filename, target_table_prefix = None, **kwargs):
3535
If provided, filters the ER Diagram to only include tables whose names begin with this prefix.
3636
3737
- Ex: If prefix = "student", only "student__address", "student__math", "student__physics" tables are included
38+
39+
`max_cols` : int, optional, default None
40+
If provided, limits the number of columns displayed for each table in the ER Diagram.
41+
If relational data is included, this must be >= number of primary and foreign keys for a table.
3842
"""
3943
super().__init__(filename, **kwargs)
4044
self.output_filename = filename
4145
self.target_table_prefix = target_table_prefix
46+
self.max_cols = max_cols
4247

4348
def get_rows(self, collection) -> None:
4449
"""
@@ -99,7 +104,23 @@ def get_rows(self, collection) -> None:
99104

100105
col_list = tableData.keys()
101106
if tableName == "dsi_units":
102-
col_list = ["table_name", "column_and_unit"]
107+
col_list = ["table_name", "column_name", "unit"]
108+
if self.max_cols is not None:
109+
if "dsi_relations" in collection.keys():
110+
fk_cols = [t[1] for t in collection["dsi_relations"]["foreign_key"] if t[0] == tableName]
111+
pk_cols = [t[1] for t in collection["dsi_relations"]["primary_key"] if t[0] == tableName]
112+
rel_cols = set(pk_cols + fk_cols)
113+
114+
if rel_cols:
115+
if len(rel_cols) > self.max_cols:
116+
return (ValueError, "'max_cols' must be >= to the number of primary/foreign key columns.")
117+
other_cols = [col for col in col_list if col not in rel_cols]
118+
combined = list(rel_cols) + other_cols[:self.max_cols - len(rel_cols)]
119+
col_list = [k for k in col_list if k in combined]
120+
col_list = col_list[:self.max_cols]
121+
if len(tableData.keys()) > self.max_cols:
122+
col_list.append("...")
123+
103124
curr_row = 0
104125
inner_brace = 0
105126
for col_name in col_list:
@@ -121,9 +142,9 @@ def get_rows(self, collection) -> None:
121142

122143
if "dsi_relations" in collection.keys():
123144
for f_table, f_col in collection["dsi_relations"]["foreign_key"]:
124-
if self.target_table_prefix is not None and self.target_table_prefix not in f_table:
145+
if self.target_table_prefix is not None and f_table is not None and self.target_table_prefix not in f_table:
125146
continue
126-
if f_table != None:
147+
if f_table is not None:
127148
foreignIndex = collection["dsi_relations"]["foreign_key"].index((f_table, f_col))
128149
fk_string = f"{f_table}:{f_col}"
129150
pk_string = f"{collection['dsi_relations']['primary_key'][foreignIndex][0]}:{collection['dsi_relations']['primary_key'][foreignIndex][1]}"
@@ -137,7 +158,10 @@ def get_rows(self, collection) -> None:
137158
subprocess.run(["dot", "-T", file_type[1:], "-o", self.output_filename + file_type, self.output_filename + ".dot"])
138159
os.remove(self.output_filename + ".dot")
139160
else:
140-
dot.render(self.output_filename, cleanup=True)
161+
try:
162+
dot.render(self.output_filename, cleanup=True)
163+
except:
164+
return (EnvironmentError, "Graphviz executable must be downloaded to global environment using sudo or homebrew.")
141165

142166
class Csv_Writer(FileWriter):
143167
"""

0 commit comments

Comments
 (0)