diff --git a/stix2/datastore/relational_db/query.py b/stix2/datastore/relational_db/query.py index f223852b..637b66e9 100644 --- a/stix2/datastore/relational_db/query.py +++ b/stix2/datastore/relational_db/query.py @@ -99,6 +99,22 @@ def _read_simple_properties(stix_id, core_table, type_table, conn): return obj_dict +def _read_simple_array(fk_id, elt_column_name, array_table, conn): + """ + Read array elements from a given table. + + :param fk_id: A foreign key value used to find the correct array elements + :param elt_column_name: The name of the table column which contains the + array elements + :param array_table: A SQLAlchemy Table object containing the array data + :param conn: An SQLAlchemy DB connection + :return: The array, as a list + """ + stmt = sa.select(array_table.c[elt_column_name]).where(array_table.c.id == fk_id) + refs = conn.scalars(stmt).all() + return refs + + def _read_hashes(fk_id, hashes_table, conn): """ Read hashes from a table. @@ -178,7 +194,7 @@ def _read_object_marking_refs(stix_id, stix_type_class, metadata, conn): return refs -def _read_granular_markings(stix_id, stix_type_class, metadata, conn): +def _read_granular_markings(stix_id, stix_type_class, metadata, conn, db_backend): """ Read granular markings from one of a couple special tables in the common schema. @@ -189,6 +205,8 @@ def _read_granular_markings(stix_id, stix_type_class, metadata, conn): :param metadata: SQLAlchemy Metadata object containing all the table information :param conn: An SQLAlchemy DB connection + :param db_backend: A backend object with information about how data is + stored in the database :return: Granular markings as a list of dicts """ @@ -200,30 +218,43 @@ def _read_granular_markings(stix_id, stix_type_class, metadata, conn): marking_table = metadata.tables["common." + marking_table_name] - stmt = sa.select( - marking_table.c.lang, - marking_table.c.marking_ref, - marking_table.c.selectors, - ).where(marking_table.c.id == stix_id) - - marking_dicts = conn.execute(stmt).mappings().all() - return marking_dicts + if db_backend.array_allowed(): + # arrays allowed: everything combined in the same table + stmt = sa.select( + marking_table.c.lang, + marking_table.c.marking_ref, + marking_table.c.selectors, + ).where(marking_table.c.id == stix_id) + marking_dicts = conn.execute(stmt).mappings().all() -def _read_simple_array(fk_id, elt_column_name, array_table, conn): - """ - Read array elements from a given table. + else: + # arrays not allowed: selectors are in their own table + stmt = sa.select( + marking_table.c.lang, + marking_table.c.marking_ref, + marking_table.c.selectors, + ).where(marking_table.c.id == stix_id) + + marking_dicts = list(conn.execute(stmt).mappings()) + + for idx, marking_dict in enumerate(marking_dicts): + # make a mutable shallow-copy of the row mapping + marking_dicts[idx] = marking_dict = dict(marking_dict) + selector_id = marking_dict.pop("selectors") + + selector_table_name = f"{marking_table.fullname}_selector" + selector_table = metadata.tables[selector_table_name] + + selectors = _read_simple_array( + selector_id, + "selector", + selector_table, + conn + ) + marking_dict["selectors"] = selectors - :param fk_id: A foreign key value used to find the correct array elements - :param elt_column_name: The name of the table column which contains the - array elements - :param array_table: A SQLAlchemy Table object containing the array data - :param conn: An SQLAlchemy DB connection - :return: The array, as a list - """ - stmt = sa.select(array_table.c[elt_column_name]).where(array_table.c.id == fk_id) - refs = conn.scalars(stmt).all() - return refs + return marking_dicts def _read_kill_chain_phases(stix_id, type_table, metadata, conn): @@ -437,10 +468,26 @@ def _read_complex_property_value(obj_id, prop_name, prop_instance, obj_table, me ref_table = metadata.tables[ref_table_name] prop_value = _read_simple_array(obj_id, "ref_id", ref_table, conn) - elif isinstance(prop_instance.contained, stix2.properties.EnumProperty): - enum_table_name = f"{obj_table.fullname}_{prop_name}" - enum_table = metadata.tables[enum_table_name] - prop_value = _read_simple_array(obj_id, prop_name, enum_table, conn) + elif isinstance(prop_instance.contained, ( + # Most of these list-of-simple-type cases would occur when array + # columns are disabled. + stix2.properties.BinaryProperty, + stix2.properties.BooleanProperty, + stix2.properties.EnumProperty, + stix2.properties.HexProperty, + stix2.properties.IntegerProperty, + stix2.properties.FloatProperty, + stix2.properties.StringProperty, + stix2.properties.TimestampProperty, + )): + array_table_name = f"{obj_table.fullname}_{prop_name}" + array_table = metadata.tables[array_table_name] + prop_value = _read_simple_array( + obj_id, + prop_name, + array_table, + conn + ) elif isinstance(prop_instance.contained, stix2.properties.EmbeddedObjectProperty): join_table_name = f"{obj_table.fullname}_{prop_name}" @@ -494,7 +541,16 @@ def _read_complex_property_value(obj_id, prop_name, prop_instance, obj_table, me return prop_value -def _read_complex_top_level_property_value(stix_id, stix_type_class, prop_name, prop_instance, type_table, metadata, conn): +def _read_complex_top_level_property_value( + stix_id, + stix_type_class, + prop_name, + prop_instance, + type_table, + metadata, + conn, + db_backend +): """ Read property values which require auxiliary tables to store. These require a lot of special cases. This function has additional support for @@ -511,6 +567,8 @@ def _read_complex_top_level_property_value(stix_id, stix_type_class, prop_name, :param metadata: SQLAlchemy Metadata object containing all the table information :param conn: An SQLAlchemy DB connection + :param db_backend: A backend object with information about how data is + stored in the database :return: The property value """ @@ -519,19 +577,44 @@ def _read_complex_top_level_property_value(stix_id, stix_type_class, prop_name, prop_value = _read_external_references(stix_id, metadata, conn) elif prop_name == "object_marking_refs": - prop_value = _read_object_marking_refs(stix_id, stix_type_class, metadata, conn) + prop_value = _read_object_marking_refs( + stix_id, + stix_type_class, + metadata, + conn + ) elif prop_name == "granular_markings": - prop_value = _read_granular_markings(stix_id, stix_type_class, metadata, conn) + prop_value = _read_granular_markings( + stix_id, + stix_type_class, + metadata, + conn, + db_backend + ) + + # Will apply when array columns are unsupported/disallowed by the backend + elif prop_name == "labels": + label_table = metadata.tables[ + f"common.core_{stix_type_class.name.lower()}_labels" + ] + prop_value = _read_simple_array(stix_id, "label", label_table, conn) else: # Other properties use specific table patterns depending on property type - prop_value = _read_complex_property_value(stix_id, prop_name, prop_instance, type_table, metadata, conn) + prop_value = _read_complex_property_value( + stix_id, + prop_name, + prop_instance, + type_table, + metadata, + conn + ) return prop_value -def read_object(stix_id, metadata, conn): +def read_object(stix_id, metadata, conn, db_backend): """ Read a STIX object from the database, identified by a STIX ID. @@ -539,6 +622,8 @@ def read_object(stix_id, metadata, conn): :param metadata: SQLAlchemy Metadata object containing all the table information :param conn: An SQLAlchemy DB connection + :param db_backend: A backend object with information about how data is + stored in the database :return: A STIX object """ _check_support(stix_id) @@ -554,7 +639,7 @@ def read_object(stix_id, metadata, conn): if type_table.schema == "common": # Applies to extension-definition SMO, whose data is stored in the # common schema; it does not get its own. This type class is used to - # determine which markings tables to use; its markings are + # determine which common tables to use; its markings are # in the *_sdo tables. stix_type_class = stix2.utils.STIXTypeClass.SDO else: @@ -578,6 +663,7 @@ def read_object(stix_id, metadata, conn): type_table, metadata, conn, + db_backend ) if prop_value is not None: diff --git a/stix2/datastore/relational_db/relational_db.py b/stix2/datastore/relational_db/relational_db.py index 2dc6009e..ee0a312f 100644 --- a/stix2/datastore/relational_db/relational_db.py +++ b/stix2/datastore/relational_db/relational_db.py @@ -195,10 +195,11 @@ def __init__( Initialize this source. Only one of stix_object_classes and metadata should be given: if the latter is given, assume table schemas are already created. Instances of this class do not create the actual - database tables; see the source/sink for that. + database tables; see the store/sink for that. Args: - database_connection_or_url: An SQLAlchemy engine object, or URL + db_backend: A database backend object + allow_custom: TODO: unused so far *stix_object_classes: STIX object classes to map into table schemas. This can be used to limit which schemas are created, if one is only working with a subset of STIX types. If not given, @@ -230,6 +231,7 @@ def get(self, stix_id, version=None, _composite_filters=None): stix_id, self.metadata, conn, + self.db_backend, ) return stix_obj diff --git a/stix2/test/v21/test_datastore_relational_db.py b/stix2/test/v21/test_datastore_relational_db.py index 5887c066..316303db 100644 --- a/stix2/test/v21/test_datastore_relational_db.py +++ b/stix2/test/v21/test_datastore_relational_db.py @@ -7,6 +7,9 @@ import stix2 from stix2.datastore import DataSourceError +from stix2.datastore.relational_db.database_backends.postgres_backend import ( + PostgresBackend, +) from stix2.datastore.relational_db.relational_db import RelationalDBStore import stix2.properties import stix2.registry @@ -15,7 +18,7 @@ _DB_CONNECT_URL = f"postgresql://{os.getenv('POSTGRES_USER', 'postgres')}:{os.getenv('POSTGRES_PASSWORD', 'postgres')}@0.0.0.0:5432/postgres" store = RelationalDBStore( - _DB_CONNECT_URL, + PostgresBackend(_DB_CONNECT_URL, True), True, None, False, @@ -878,7 +881,7 @@ def test_property(object_variation): ensure schemas can be created and values can be stored and retrieved. """ rdb_store = RelationalDBStore( - _DB_CONNECT_URL, + PostgresBackend(_DB_CONNECT_URL, True), True, None, True, @@ -918,7 +921,7 @@ def test_dictionary_property_complex(): ) rdb_store = RelationalDBStore( - _DB_CONNECT_URL, + PostgresBackend(_DB_CONNECT_URL, True), True, None, True, @@ -934,6 +937,7 @@ def test_dictionary_property_complex(): def test_extension_definition(): obj = stix2.ExtensionDefinition( created_by_ref="identity--8a5fb7e4-aabe-4635-8972-cbcde1fa4792", + labels=["label1", "label2"], name="test", schema="a schema", version="1.2.3",