Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 73 additions & 11 deletions stix2/datastore/relational_db/query.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import collections
import inspect

import sqlalchemy as sa
Expand Down Expand Up @@ -280,7 +281,15 @@ def _read_kill_chain_phases(stix_id, type_table, metadata, conn):
return kill_chain_phases


def _read_dictionary_property(stix_id, type_table, prop_name, prop_instance, metadata, conn):
def _read_dictionary_property(
stix_id,
type_table,
prop_name,
prop_instance,
metadata,
conn,
db_backend
):
"""
Read a dictionary from a table.

Expand All @@ -292,23 +301,57 @@ def _read_dictionary_property(stix_id, type_table, prop_name, prop_instance, met
:param metadata: SQLAlchemy Metadata object containing all the table
information
:param conn: An SQLAlchemy DB connection
:param db_backend: A backend object with information about how data is
stored in the database
:return: The dictionary, or None if no dictionary entries were found
"""
dict_table_name = f"{type_table.fullname}_{prop_name}"
dict_table = metadata.tables[dict_table_name]

if len(prop_instance.valid_types) == 1:
stmt = sa.select(
dict_table.c.name, dict_table.c.value,
).where(
dict_table.c.id == stix_id,
)
valid_type = prop_instance.valid_types[0]

if isinstance(valid_type, stix2.properties.ListProperty):
if db_backend.array_allowed():
stmt = sa.select(
dict_table.c.name, dict_table.c["values"],
).where(
dict_table.c.id == stix_id,
)

results = conn.execute(stmt)
dict_value = dict(results.all())
results = conn.execute(stmt)
dict_value = dict(results.all())

else:
# Dict contains a list, but array columns are not supported.
# So query from an auxiliary table.
list_table_name = f"{dict_table_name}_values"
list_table = metadata.tables[list_table_name]
stmt = sa.select(
dict_table.c.name, list_table.c.value
).select_from(dict_table).join(
list_table, list_table.c.id == dict_table.c.values
).where(
dict_table.c.id == stix_id
)

results = conn.execute(stmt)
dict_value = collections.defaultdict(list)
for key, value in results:
dict_value[key].append(value)

else:
stmt = sa.select(
dict_table.c.name, dict_table.c.value,
).where(
dict_table.c.id == stix_id,
)

results = conn.execute(stmt)
dict_value = dict(results.all())

else:
# In this case, we get one column per valid type
# In this case, we get one column per valid type (assume no lists here)
type_cols = (col for col in dict_table.c if col.key not in ("id", "name"))
stmt = sa.select(dict_table.c.name, *type_cols).where(dict_table.c.id == stix_id)
results = conn.execute(stmt)
Expand Down Expand Up @@ -437,7 +480,15 @@ def _read_embedded_object_list(fk_id, join_table, embedded_type, metadata, conn)
return obj_list


def _read_complex_property_value(obj_id, prop_name, prop_instance, obj_table, metadata, conn):
def _read_complex_property_value(
obj_id,
prop_name,
prop_instance,
obj_table,
metadata,
conn,
db_backend
):
"""
Read property values which require auxiliary tables to store. These are
idiosyncratic and just require a lot of special cases. This function has
Expand All @@ -456,6 +507,8 @@ def _read_complex_property_value(obj_id, prop_name, prop_instance, obj_table, me
:param metadata: SQLAlchemy Metadata object containing all the table
information
:param conn: An SQLAlchemy DB connection
:param db_backend: A backend object with information about how data is
stored in the database
:return: The property value
"""

Expand Down Expand Up @@ -523,7 +576,15 @@ def _read_complex_property_value(obj_id, prop_name, prop_instance, obj_table, me
elif isinstance(prop_instance, stix2.properties.DictionaryProperty):
# ExtensionsProperty/HashesProperty subclasses DictionaryProperty, so
# this must come after those
prop_value = _read_dictionary_property(obj_id, obj_table, prop_name, prop_instance, metadata, conn)
prop_value = _read_dictionary_property(
obj_id,
obj_table,
prop_name,
prop_instance,
metadata,
conn,
db_backend
)

elif isinstance(prop_instance, stix2.properties.EmbeddedObjectProperty):
prop_value = _read_embedded_object(
Expand Down Expand Up @@ -611,6 +672,7 @@ def _read_complex_top_level_property_value(
type_table,
metadata,
conn,
db_backend
)

return prop_value
Expand Down
Loading