diff --git a/stix2/datastore/relational_db/query.py b/stix2/datastore/relational_db/query.py index 14685605..5cdf14cd 100644 --- a/stix2/datastore/relational_db/query.py +++ b/stix2/datastore/relational_db/query.py @@ -1,3 +1,4 @@ +import collections import inspect import sqlalchemy as sa @@ -280,7 +281,15 @@ def _read_kill_chain_phases(stix_id, type_table, metadata, conn): return kill_chain_phases -def _read_dictionary_property(stix_id, type_table, prop_name, prop_instance, metadata, conn): +def _read_dictionary_property( + stix_id, + type_table, + prop_name, + prop_instance, + metadata, + conn, + db_backend +): """ Read a dictionary from a table. @@ -292,23 +301,57 @@ def _read_dictionary_property(stix_id, type_table, prop_name, prop_instance, met :param metadata: SQLAlchemy Metadata object containing all the table information :param conn: An SQLAlchemy DB connection + :param db_backend: A backend object with information about how data is + stored in the database :return: The dictionary, or None if no dictionary entries were found """ dict_table_name = f"{type_table.fullname}_{prop_name}" dict_table = metadata.tables[dict_table_name] if len(prop_instance.valid_types) == 1: - stmt = sa.select( - dict_table.c.name, dict_table.c.value, - ).where( - dict_table.c.id == stix_id, - ) + valid_type = prop_instance.valid_types[0] + + if isinstance(valid_type, stix2.properties.ListProperty): + if db_backend.array_allowed(): + stmt = sa.select( + dict_table.c.name, dict_table.c["values"], + ).where( + dict_table.c.id == stix_id, + ) - results = conn.execute(stmt) - dict_value = dict(results.all()) + results = conn.execute(stmt) + dict_value = dict(results.all()) + + else: + # Dict contains a list, but array columns are not supported. + # So query from an auxiliary table. 
+ list_table_name = f"{dict_table_name}_values" + list_table = metadata.tables[list_table_name] + stmt = sa.select( + dict_table.c.name, list_table.c.value + ).select_from(dict_table).join( + list_table, list_table.c.id == dict_table.c["values"] + ).where( + dict_table.c.id == stix_id + ) + + results = conn.execute(stmt) + dict_value = collections.defaultdict(list) + for key, value in results: + dict_value[key].append(value) + + else: + stmt = sa.select( + dict_table.c.name, dict_table.c.value, + ).where( + dict_table.c.id == stix_id, + ) + + results = conn.execute(stmt) + dict_value = dict(results.all()) else: - # In this case, we get one column per valid type + # In this case, we get one column per valid type (assume no lists here) type_cols = (col for col in dict_table.c if col.key not in ("id", "name")) stmt = sa.select(dict_table.c.name, *type_cols).where(dict_table.c.id == stix_id) results = conn.execute(stmt) @@ -437,7 +480,15 @@ def _read_embedded_object_list(fk_id, join_table, embedded_type, metadata, conn) return obj_list -def _read_complex_property_value(obj_id, prop_name, prop_instance, obj_table, metadata, conn): +def _read_complex_property_value( + obj_id, + prop_name, + prop_instance, + obj_table, + metadata, + conn, + db_backend +): """ Read property values which require auxiliary tables to store. These are idiosyncratic and just require a lot of special cases. 
This function has @@ -456,6 +507,8 @@ def _read_complex_property_value(obj_id, prop_name, prop_instance, obj_table, me :param metadata: SQLAlchemy Metadata object containing all the table information :param conn: An SQLAlchemy DB connection + :param db_backend: A backend object with information about how data is + stored in the database :return: The property value """ @@ -523,7 +576,15 @@ def _read_complex_property_value(obj_id, prop_name, prop_instance, obj_table, me elif isinstance(prop_instance, stix2.properties.DictionaryProperty): # ExtensionsProperty/HashesProperty subclasses DictionaryProperty, so # this must come after those - prop_value = _read_dictionary_property(obj_id, obj_table, prop_name, prop_instance, metadata, conn) + prop_value = _read_dictionary_property( + obj_id, + obj_table, + prop_name, + prop_instance, + metadata, + conn, + db_backend + ) elif isinstance(prop_instance, stix2.properties.EmbeddedObjectProperty): prop_value = _read_embedded_object( @@ -611,6 +672,7 @@ def _read_complex_top_level_property_value( type_table, metadata, conn, + db_backend ) return prop_value