Skip to content

Commit 5ec5a15

Browse files
authored
Merge pull request #144 from arachne-threat-intel/dependencies-update
psycopg2 to psycopg
2 parents efb9a79 + d282b03 commit 5ec5a15

File tree

7 files changed

+215
-131
lines changed

7 files changed

+215
-131
lines changed

documentation/database.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ This is the default database engine which Thread uses.
2020

2121
PostgreSQL is available in all Ubuntu versions by default. For other environments, you may refer to the [PostgreSQL website](https://www.postgresql.org/download/) to download it.
2222

23-
A required Python package is `psycopg2`. [Pre-requisites](https://www.psycopg.org/docs/install.html#prerequisites) must be fulfilled. (This package is currently commented-out of the [requirements file](../requirements.txt), you may uncomment it to be included in your requirements-installation step.)
23+
A required Python package is `psycopg`. Please refer to [this link](https://www.psycopg.org/psycopg3/docs/basic/install.html#pure-python-installation) for installation guidance. (This package is currently commented-out of the [requirements file](../requirements.txt), you may uncomment it to be included in your requirements-installation step.)
2424

2525
### 2. Create User
2626

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@ requests~=2.32
1818
scikit-learn~=1.3
1919
stix2~=3.0
2020
# Uncomment if using PostgreSQL and have satisfied its requirements
21-
# psycopg2~=2.9
21+
# psycopg~=3.2

spindle

threadcomponents/conf/config.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ taxii-local: taxii-server
1616
# Either 'js-online-src' for fetching JS libraries online or 'js-local-src' to use locally saved JS libraries
1717
js-libraries: js-online-src
1818
# The database backend to use: currently either 'sqlite3' or 'postgresql' (sqlite3 recommended for local-use)
19-
# If using postgresql, requires package psycopg2 (see pre-reqs: https://www.psycopg.org/docs/install.html#prerequisites)
20-
# If not using postgresql, you can (re)move the file database/thread_postgresql.py to avoid installing psycopg2
19+
# If using postgresql, requires package psycopg (see pre-reqs: https://www.psycopg.org/psycopg3/docs/basic/install.html)
20+
# If not using postgresql, you can (re)move the file database/thread_postgresql.py to avoid installing psycopg
2121
db-engine: sqlite3
2222
# If you would like the database to be re-built on launch of Thread
2323
# Ineffective when db-engine = 'postgresql'; if wanted, call `main.py --build-db` separately (before launching Thread)

threadcomponents/database/thread_db.py

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
def find_create_statement_in_schema(schema, table, log_error=True, find_closing_bracket=False):
1414
"""Helper-method to return the start and end positions of an SQL create statement in a given schema."""
1515
# The possible matches when finding the CREATE statement (with a space and opening bracket or just the bracket)
16-
start_statements = ["%s %s (" % (CREATE_BEGIN, table), "%s %s(" % (CREATE_BEGIN, table)]
16+
start_statements = [f"{CREATE_BEGIN} {table} (", f"{CREATE_BEGIN} {table}("]
1717
start_pos, error = None, ValueError()
1818
for start_statement in start_statements:
1919
try:
@@ -26,15 +26,15 @@ def find_create_statement_in_schema(schema, table, log_error=True, find_closing_
2626
# Start-position not found
2727
if not start_pos:
2828
if log_error:
29-
logging.error("Table `%s` missing: given schema has different or missing CREATE statement." % table)
29+
logging.error(f"Table `{table}` missing: given schema has different or missing CREATE statement.")
3030
raise error
3131
# Given a starting position, find where the table's create statement finishes
3232
try:
3333
end_pos = schema[start_pos:].index(CREATE_END)
3434
# End-position not found
3535
except ValueError as e:
3636
if log_error:
37-
logging.error("SQL error: could not find closing `%s` for table `%s` in schema." % (CREATE_END, table))
37+
logging.error(f"SQL error: could not find closing `{CREATE_END}` for table `{table}` in schema.")
3838
raise e
3939
# Consider the end position to be before any foreign key statements; they might not be present so ignore ValueErrors
4040
if not find_closing_bracket:
@@ -98,18 +98,20 @@ def get_function_name(self, func_key, *args, unquote=None):
9898
unquote = unquote or []
9999
# Get the function name according to the mapped_functions dictionary
100100
func_name = self._mapped_functions.get(func_key)
101+
101102
# If there is nothing to retrieve, return None
102103
if func_name is None:
103104
return None
105+
104106
# If we have args, construct the string f(a, b, ...) (where str args except fields and query params are quoted)
105107
if args:
106-
return "%s(%s)" % (
107-
func_name,
108+
return f"{func_name}(%s)" % (
108109
", ".join(
109110
("'%s'" % x if (isinstance(x, str) and (x != self.query_param) and (x not in unquote)) else str(x))
110111
for x in args
111112
),
112113
)
114+
113115
# Else if no args are supplied, just return the function name
114116
else:
115117
return func_name
@@ -131,6 +133,7 @@ def generate_copied_tables(schema=""):
131133
"""Function to return a new schema that has copied structures of report-sentence tables from a given schema."""
132134
# The new schema to build on and return
133135
new_schema = ""
136+
134137
# For each table that we are copying the structure of...
135138
for table in TABLES_WITH_BACKUPS:
136139
# Obtain the start and end positions of the SQL create statement for this table
@@ -140,10 +143,12 @@ def generate_copied_tables(schema=""):
140143
create_statement = schema[start_pos : (end_pos + len(CREATE_END))]
141144
# Add the create statement for this table to the new schema
142145
new_schema += "\n\n" + create_statement
146+
143147
# Now that the new schema has the tables we want copied, replace mention of the table name with '<name>_initial'
144148
# We want all occurrences replaced because of foreign key constraints
145149
for table in TABLES_WITH_BACKUPS:
146-
new_schema = new_schema.replace(table, "%s%s" % (table, BACKUP_TABLE_SUFFIX))
150+
new_schema = new_schema.replace(table, f"{table}{BACKUP_TABLE_SUFFIX}")
151+
147152
# Return the new schema
148153
return new_schema.strip()
149154

@@ -189,48 +194,59 @@ def sql_date_field_to_str(self, sql, field_name_as=None, str_suffix=False):
189194
field_name_pos = sql.rfind(".")
190195
if field_name_pos > -1:
191196
field_name_as = sql[field_name_pos + 1 :]
197+
192198
# New field name is what has been provided, calculated or the original sql given
193199
field_name_as = field_name_as or sql
194200
# If we are adding str at the end, do this (useful when SELECT *, to_char(... to prevent duplicate column names)
201+
195202
if str_suffix:
196203
field_name_as += "_str"
204+
197205
# Construct and return sql statement
198206
converter = self.get_function_name(self.FUNC_DATE_TO_STR, sql, "YYYY-MM-DD", unquote=[sql])
199-
return "%s AS %s" % (converter or sql, field_name_as)
207+
return f"{converter or sql} AS {field_name_as}"
200208

201209
@staticmethod
202210
def _check_method_parameters(table, data, data_allowed_as_none=False, method_name="unspecified"):
203211
"""Function to check parameters passed to CRUD methods."""
204212
# Check the table is a string
205213
if not isinstance(table, str):
206-
raise TypeError("Non-string arg passed for table in ThreadDB.%s(): %s" % (method_name, str(table)))
214+
raise TypeError(f"Non-string arg passed for table in ThreadDB.{method_name}(): {table}")
215+
207216
# Proceed with checks if data is non-None but allowed to be so
208217
if data_allowed_as_none and data is None:
209218
return
219+
210220
# Check values passed to this method are dictionaries (column=value key-val pairs)
211221
if not isinstance(data, dict):
212-
raise TypeError("Non-dictionary arg passed in ThreadDB.%s(table=%s): %s" % (method_name, table, str(data)))
222+
raise TypeError(f"Non-dictionary arg passed in ThreadDB.{method_name}(table={table}): {data}")
223+
213224
# If the data is not allowed to be None (or empty), check data has been provided
214225
if (not data_allowed_as_none) and (not len(data)):
215-
raise ValueError("Non-empty-dictionary must be passed in ThreadDB.%s(table=%s)" % (method_name, table))
226+
raise ValueError(f"Non-empty-dictionary must be passed in ThreadDB.{method_name}(table={table})")
216227

217228
async def get(self, table, equal=None, not_equal=None, order_by_asc=None, order_by_desc=None):
218229
"""Method to return values from a db table optionally based on equals or not-equals criteria."""
219230
# Check values passed to this method are valid
220231
for param in [equal, not_equal, order_by_asc, order_by_desc]:
221232
# Allow None values as we do checks for this but non-None values should be dictionaries
222233
self._check_method_parameters(table, param, data_allowed_as_none=True, method_name="get")
234+
223235
# Proceed with method
224-
sql = "SELECT * FROM %s" % table
236+
sql = f"SELECT * FROM {table}"
237+
225238
# Define all_params dictionary (for equal and not_equal to be None-checked and combined)
226239
# all_ordering dictionary (for ASC and DESC ordering combined) and qparams list
227240
all_params, all_ordering, qparams = dict(), dict(), []
241+
228242
# Append to all_params equal and not_equal if not None
229243
all_params.update(dict(equal=equal) if equal else {})
230244
all_params.update(dict(not_equal=not_equal) if not_equal else {})
245+
231246
# Do the same for the ordering dictionaries
232247
all_ordering.update(dict(asc=order_by_asc) if order_by_asc else {})
233248
all_ordering.update(dict(desc=order_by_desc) if order_by_desc else {})
249+
234250
# For each of the equal and not_equal parameters, build SQL query
235251
count = 0
236252
for eq, criteria in all_params.items():
@@ -241,12 +257,13 @@ async def get(self, table, equal=None, not_equal=None, order_by_asc=None, order_
241257
sql += " AND" if count > 0 else " WHERE"
242258
if value is None:
243259
# Do a NULL check for the column
244-
sql += " %s IS%s NULL" % (where, " NOT" if eq == "not_equal" else "")
260+
sql += f" {where} IS{' NOT' if eq == 'not_equal' else ''} NULL"
245261
else:
246262
# Add the ! for != if this is a not-equals check
247-
sql += " %s %s= %s" % (where, "!" if eq == "not_equal" else "", self.query_param)
263+
sql += f" {where} {'!' if eq == 'not_equal' else ''}= {self.query_param}"
248264
qparams.append(value)
249265
count += 1
266+
250267
# For each of the ordering parameters, build the ORDER BY clause of the SQL query
251268
count = 0
252269
for order_by, criteria in all_ordering.items():
@@ -258,14 +275,15 @@ async def get(self, table, equal=None, not_equal=None, order_by_asc=None, order_
258275
# If the boolean value for this column to be ordered is True...
259276
if value:
260277
# Add column name and ASC/DESC criteria
261-
sql += " %s %s" % (where, order_by.upper())
278+
sql += f" {where} {order_by.upper()}"
262279
count += 1
280+
263281
# After the SQL query has been formed, execute it
264282
return await self._execute_select(sql, parameters=qparams)
265283

266284
async def get_column_as_list(self, table, column):
267285
"""Method to return a column from a db table as a list."""
268-
return await self.raw_select("SELECT %s FROM %s" % (column, table), single_col=True)
286+
return await self.raw_select(f"SELECT {column} FROM {table}", single_col=True)
269287

270288
async def get_dict_value_as_key(self, column_key, table=None, columns=None, sql=None, sql_params=None):
271289
"""Method to return a dictionary of results where the key is a column's value.
@@ -285,7 +303,7 @@ def on_fetch(results):
285303
# Insert the columns in the SQL statement depending on its type
286304
# Need to add the column-key to the query, so we get the values for that column
287305
if isinstance(columns, str):
288-
sql = sql % (columns + ", " + column_key)
306+
sql = sql % f"{columns}, {column_key}"
289307
elif isinstance(columns, list):
290308
columns.append(column_key)
291309
sql = sql % ", ".join(columns)
@@ -298,7 +316,7 @@ async def initialise_column_names(self):
298316
# We currently only care about storing initial data table columns for INSERT INTO SELECT statements
299317
for table in TABLES_WITH_BACKUPS:
300318
# Access no data but select all columns for the given table
301-
sql = "SELECT * FROM %s LIMIT 0" % table
319+
sql = f"SELECT * FROM {table} LIMIT 0"
302320
# Update map with the list of columns obtained from this SQL statement
303321
self._table_columns[table] = await self._get_column_names(sql)
304322

@@ -345,7 +363,7 @@ async def insert_with_backup(self, table, data, id_field="uid"):
345363
copied_data = dict(data)
346364
copied_data[id_field] = record_id
347365
# Insert the copied data into the backup table
348-
await self.insert("%s%s" % (table, BACKUP_TABLE_SUFFIX), copied_data)
366+
await self.insert(f"{table}{BACKUP_TABLE_SUFFIX}", copied_data)
349367
# Return the ID for the two records
350368
return record_id
351369

@@ -354,10 +372,13 @@ async def update(self, table, where=None, data=None, return_sql=False):
354372
# Check values passed to this method are valid
355373
self._check_method_parameters(table, data, method_name="update")
356374
self._check_method_parameters(table, where, method_name="update")
375+
357376
# The list of query parameters
358377
qparams = []
378+
359379
# Our SQL statement and optional WHERE clause
360380
sql, where_suffix = "UPDATE {} SET".format(table), ""
381+
361382
# Appending the SET terms; keep a count
362383
count = 0
363384
for k, v in data.items():
@@ -372,6 +393,7 @@ async def update(self, table, where=None, data=None, return_sql=False):
372393
# Update qparams for this value to be substituted
373394
qparams.append(v)
374395
count += 1
396+
375397
# Appending the WHERE terms; keep a count
376398
count = 0
377399
for wk, wv in where.items():
@@ -386,34 +408,39 @@ async def update(self, table, where=None, data=None, return_sql=False):
386408
# Update qparams for this value to be substituted
387409
qparams.append(wv)
388410
count += 1
411+
389412
# Finalise WHERE clause if we had items added to it
390413
where_suffix = "" if where_suffix == "" else " WHERE" + where_suffix
414+
391415
# Add the WHERE clause to the SQL statement
392416
sql += where_suffix
393417
if return_sql:
394418
return tuple([sql, tuple(qparams)])
419+
395420
# Run the statement by passing qparams as parameters
396421
return await self._execute_update(sql, qparams)
397422

398423
async def delete(self, table, data, return_sql=False):
399424
"""Method to delete rows from a table of the db."""
400425
# Check values passed to this method are valid
401426
self._check_method_parameters(table, data, method_name="delete")
402-
sql = "DELETE FROM %s" % table
427+
sql = f"DELETE FROM {table}"
403428
qparams = []
429+
404430
# Construct the WHERE clause using the data
405431
count = 0
406432
for k, v in data.items():
407433
# If this is our first criteria we are adding, we need the WHERE keyword, else adding AND
408434
sql += " AND" if count > 0 else " WHERE"
409435
if v is None:
410436
# Do a NULL check for the column
411-
sql += " %s IS NULL" % k
437+
sql += f" {k} IS NULL"
412438
else:
413439
# Add the ! for != if this is a not-equals check
414-
sql += " %s = %s" % (k, self.query_param)
440+
sql += f" {k} = {self.query_param}"
415441
qparams.append(v)
416442
count += 1
443+
417444
if return_sql:
418445
return tuple([sql, tuple(qparams)])
419446
# Run the statement by passing qparams as parameters

0 commit comments

Comments
 (0)