diff --git a/stix2/datastore/relational_db/query.py b/stix2/datastore/relational_db/query.py index 5cdf14cd..a1b1619b 100644 --- a/stix2/datastore/relational_db/query.py +++ b/stix2/datastore/relational_db/query.py @@ -30,7 +30,7 @@ def _check_support(stix_id): raise DataSourceError(f"Reading {stix_type} objects is not supported.") -def _tables_for(stix_class, metadata): +def _tables_for(stix_class, metadata, db_backend): """ Get the core and type-specific tables for the given class @@ -41,17 +41,24 @@ def _tables_for(stix_class, metadata): """ # Info about the type-specific table type_table_name = table_name_for(stix_class) - type_schema_name = schema_for(stix_class) - type_table = metadata.tables[f"{type_schema_name}.{type_table_name}"] + type_schema_name = db_backend.schema_for(stix_class) + canon_type_table_name = canonicalize_table_name(type_table_name, type_schema_name) + + type_table = metadata.tables[canon_type_table_name] # Some fixed info about core tables - if type_schema_name == "sco": - core_table_name = "common.core_sco" + if stix2.utils.is_sco(stix_class._type, stix2.DEFAULT_VERSION): + canon_core_table_name = canonicalize_table_name( + "core_sco", db_backend.schema_for_core() + ) + else: # for SROs and SMOs too? - core_table_name = "common.core_sdo" + canon_core_table_name = canonicalize_table_name( + "core_sdo", db_backend.schema_for_core() + ) - core_table = metadata.tables[core_table_name] + core_table = metadata.tables[canon_core_table_name] return core_table, type_table @@ -134,7 +141,7 @@ def _read_hashes(fk_id, hashes_table, conn): return hashes -def _read_external_references(stix_id, metadata, conn): +def _read_external_references(stix_id, metadata, conn, db_backend): """ Read external references from some fixed tables in the common schema. @@ -142,10 +149,22 @@ def _read_external_references(stix_id, metadata, conn): :param metadata: SQLAlchemy Metadata object containing all the table information :param conn: An SQLAlchemy DB connection + :param db_backend: A backend object with information about how data is + stored in the database :return: The external references, as a list of dicts """ - ext_refs_table = metadata.tables["common.external_references"] - ext_refs_hashes_table = metadata.tables["common.external_references_hashes"] + ext_refs_table = metadata.tables[ + canonicalize_table_name( + "external_references", + db_backend.schema_for_core() + ) + ] + ext_refs_hashes_table = metadata.tables[ + canonicalize_table_name( + "external_references_hashes", + db_backend.schema_for_core() + ) + ] ext_refs = [] ext_refs_columns = (col for col in ext_refs_table.c if col.key != "id") @@ -165,29 +184,30 @@ def _read_external_references(stix_id, metadata, conn): return ext_refs -def _read_object_marking_refs(stix_id, stix_type_class, metadata, conn): +def _read_object_marking_refs(stix_id, common_table_kind, metadata, conn, db_backend): """ Read object marking refs from one of a couple special tables in the common schema. :param stix_id: A STIX ID, used to filter table rows - :param stix_type_class: STIXTypeClass enum value, used to determine whether + :param common_table_kind: "sco" or "sdo", used to determine whether to read the table for SDOs or SCOs :param metadata: SQLAlchemy Metadata object containing all the table information :param conn: An SQLAlchemy DB connection + :param db_backend: A backend object with information about how data is + stored in the database :return: The references as a list of strings """ - marking_table_name = "object_marking_refs_" - if stix_type_class is stix2.utils.STIXTypeClass.SCO: - marking_table_name += "sco" - else: - marking_table_name += "sdo" + marking_table_name = canonicalize_table_name( + "object_marking_refs_" + common_table_kind, + db_backend.schema_for_core() + ) # The SCO/SDO object_marking_refs tables are mostly identical; they just # have different foreign key constraints (to different core tables). - marking_table = metadata.tables["common." + marking_table_name] + marking_table = metadata.tables[marking_table_name] stmt = sa.select(marking_table.c.ref_id).where(marking_table.c.id == stix_id) refs = conn.scalars(stmt).all() @@ -195,13 +215,13 @@ def _read_object_marking_refs(stix_id, stix_type_class, metadata, conn): return refs -def _read_granular_markings(stix_id, stix_type_class, metadata, conn, db_backend): +def _read_granular_markings(stix_id, common_table_kind, metadata, conn, db_backend): """ Read granular markings from one of a couple special tables in the common schema. :param stix_id: A STIX ID, used to filter table rows - :param stix_type_class: STIXTypeClass enum value, used to determine whether + :param common_table_kind: "sco" or "sdo", used to determine whether to read the table for SDOs or SCOs :param metadata: SQLAlchemy Metadata object containing all the table information @@ -211,13 +231,11 @@ def _read_granular_markings(stix_id, stix_type_class, metadata, conn, db_backend :return: Granular markings as a list of dicts """ - marking_table_name = "granular_marking_" - if stix_type_class is stix2.utils.STIXTypeClass.SCO: - marking_table_name += "sco" - else: - marking_table_name += "sdo" - - marking_table = metadata.tables["common." + marking_table_name] + marking_table_name = canonicalize_table_name( + "granular_marking_" + common_table_kind, + db_backend.schema_for_core() + ) + marking_table = metadata.tables[marking_table_name] if db_backend.array_allowed(): # arrays allowed: everything combined in the same table @@ -330,7 +348,7 @@ def _read_dictionary_property( stmt = sa.select( dict_table.c.name, list_table.c.value ).select_from(dict_table).join( - list_table, list_table.c.id == dict_table.c.values + list_table, list_table.c.id == dict_table.c["values"] ).where( dict_table.c.id == stix_id ) @@ -606,7 +624,7 @@ def _read_complex_property_value( def _read_complex_top_level_property_value( stix_id, - stix_type_class, + common_table_kind, prop_name, prop_instance, type_table, @@ -620,8 +638,8 @@ def _read_complex_top_level_property_value( reading top-level common properties, which use special fixed tables. :param stix_id: STIX ID of an object to read - :param stix_type_class: The kind of object (SCO, SDO, etc). Which DB - tables to read can depend on this. + :param common_table_kind: Used to find auxiliary common tables, e.g. those + for object markings, granular markings, etc. Either "sco" or "sdo". :param prop_name: The name of the property to read :param prop_instance: A Property (subclass) instance with property config information @@ -637,20 +655,26 @@ def _read_complex_top_level_property_value( # Common properties: these use a fixed set of tables for all STIX objects if prop_name == "external_references": - prop_value = _read_external_references(stix_id, metadata, conn) + prop_value = _read_external_references( + stix_id, + metadata, + conn, + db_backend + ) elif prop_name == "object_marking_refs": prop_value = _read_object_marking_refs( stix_id, - stix_type_class, + common_table_kind, metadata, conn, + db_backend, ) elif prop_name == "granular_markings": prop_value = _read_granular_markings( stix_id, - stix_type_class, + common_table_kind, metadata, conn, db_backend, @@ -659,7 +683,10 @@ def _read_complex_top_level_property_value( # Will apply when array columns are unsupported/disallowed by the backend elif prop_name == "labels": label_table = metadata.tables[ - f"common.core_{stix_type_class.name.lower()}_labels" + canonicalize_table_name( + f"core_{common_table_kind}_labels", + db_backend.schema_for_core() + ) ] prop_value = _read_simple_array(stix_id, "label", label_table, conn) @@ -698,16 +725,10 @@ def read_object(stix_id, metadata, conn, db_backend): stix_type = stix2.utils.get_type_from_id(stix_id) raise DataSourceError("Can't find registered class for type: " + stix_type) - core_table, type_table = _tables_for(stix_class, metadata) - - if type_table.schema == "common": - # Applies to extension-definition SMO, whose data is stored in the - # common schema; it does not get its own. This type class is used to - # determine which common tables to use; its markings are - # in the *_sdo tables. - stix_type_class = stix2.utils.STIXTypeClass.SDO - else: - stix_type_class = stix2.utils.to_enum(type_table.schema, stix2.utils.STIXTypeClass) + core_table, type_table = _tables_for(stix_class, metadata, db_backend) + # Used to find auxiliary common tables, e.g. those for object markings, + # granular markings, etc. + common_table_kind = core_table.name[-3:] simple_props = _read_simple_properties(stix_id, core_table, type_table, conn) if simple_props is None: @@ -721,7 +742,7 @@ def read_object(stix_id, metadata, conn, db_backend): if prop_name not in obj_dict: prop_value = _read_complex_top_level_property_value( stix_id, - stix_type_class, + common_table_kind, prop_name, prop_instance, type_table, @@ -733,5 +754,10 @@ def read_object(stix_id, metadata, conn, db_backend): if prop_value is not None: obj_dict[prop_name] = prop_value - stix_obj = stix_class(**obj_dict, allow_custom=True) + stix_obj = stix2.parse( + obj_dict, + allow_custom=True, + version=stix2.DEFAULT_VERSION + ) + return stix_obj diff --git a/stix2/test/v21/test_datastore_relational_db.py b/stix2/test/v21/test_datastore_relational_db.py index 518ef307..d6d9d4fe 100644 --- a/stix2/test/v21/test_datastore_relational_db.py +++ b/stix2/test/v21/test_datastore_relational_db.py @@ -10,19 +10,49 @@ from stix2.datastore.relational_db.database_backends.postgres_backend import ( PostgresBackend, ) +from stix2.datastore.relational_db.database_backends.sqlite_backend import SQLiteBackend from stix2.datastore.relational_db.relational_db import RelationalDBStore import stix2.properties import stix2.registry import stix2.v21 -_DB_CONNECT_URL = f"postgresql://{os.getenv('POSTGRES_USER', 'postgres')}:{os.getenv('POSTGRES_PASSWORD', 'postgres')}@0.0.0.0:5432/{os.getenv('POSTGRES_DB', 'postgres')}" -store = RelationalDBStore( - PostgresBackend(_DB_CONNECT_URL, True), - True, - None, - True, +@pytest.fixture( + scope="module", + params=["postgresql", "sqlite"] ) +def db_backend(request): + if request.param == "postgresql": + user = os.getenv('POSTGRES_USER', 'postgres') + pass_ = os.getenv('POSTGRES_PASSWORD', 'postgres') + dbname = os.getenv('POSTGRES_DB', 'postgres') + + connect_url = f"postgresql://{user}:{pass_}@0.0.0.0:5432/{dbname}" + backend = PostgresBackend(connect_url, force_recreate=True) + + elif request.param == "sqlite": + connect_url = "sqlite://" # in-memory DB + backend = SQLiteBackend(connect_url, force_recreate=True) + + else: + raise ValueError(request.param) + + return backend + + +@pytest.fixture +def store(db_backend): + store = RelationalDBStore( + db_backend, + True, + None, + True, + ) + + yield store + + store.metadata.drop_all(db_backend.database_connection) + # Artifacts basic_artifact_dict = { @@ -48,7 +78,7 @@ } -def test_basic_artifact(): +def test_basic_artifact(store): artifact_stix_object = stix2.parse(basic_artifact_dict) store.add(artifact_stix_object) read_obj = json.loads(store.get(artifact_stix_object['id']).serialize()) @@ -57,7 +87,7 @@ def test_basic_artifact(): assert basic_artifact_dict[attrib] == read_obj[attrib] -def test_encrypted_artifact(): +def test_encrypted_artifact(store): artifact_stix_object = stix2.parse(encrypted_artifact_dict) store.add(artifact_stix_object) read_obj = json.loads(store.get(artifact_stix_object['id']).serialize()) @@ -77,7 +107,7 @@ def test_encrypted_artifact(): } -def test_autonomous_system(): +def test_autonomous_system(store): as_obj = stix2.parse(as_dict) store.add(as_obj) read_obj = json.loads(store.get(as_obj['id']).serialize()) @@ -102,7 +132,7 @@ def test_autonomous_system(): } -def test_directory(): +def test_directory(store): directory_obj = stix2.parse(directory_dict) store.add(directory_obj) read_obj = json.loads(store.get(directory_obj['id']).serialize()) @@ -123,7 +153,7 @@ def test_directory(): } -def test_domain_name(): +def test_domain_name(store): domain_name_obj = stix2.parse(domain_name_dict) store.add(domain_name_obj) read_obj = json.loads(store.get(domain_name_obj['id']).serialize()) @@ -143,7 +173,7 @@ def test_domain_name(): } -def test_email_addr(): +def test_email_addr(store): email_addr_stix_object = stix2.parse(email_addr_dict) store.add(email_addr_stix_object) read_obj = json.loads(store.get(email_addr_stix_object['id']).serialize()) @@ -228,7 +258,7 @@ def test_email_addr(): } -def test_email_msg(): +def test_email_msg(store): email_msg_stix_object = stix2.parse(email_msg_dict) store.add(email_msg_stix_object) read_obj = json.loads(store.get(email_msg_stix_object['id']).serialize()) @@ -242,7 +272,7 @@ def test_email_msg(): assert email_msg_dict[attrib] == read_obj[attrib] -def test_multipart_email_msg(): +def test_multipart_email_msg(store): multipart_email_msg_stix_object = stix2.parse(multipart_email_msg_dict) store.add(multipart_email_msg_stix_object) read_obj = json.loads(store.get(multipart_email_msg_stix_object['id']).serialize()) @@ -281,7 +311,7 @@ def test_multipart_email_msg(): } -def test_file(): +def test_file(store): file_stix_object = stix2.parse(file_dict) store.add(file_stix_object) read_obj = json.loads(store.get(file_stix_object['id']).serialize()) @@ -309,7 +339,7 @@ def test_file(): } -def test_ipv4(): +def test_ipv4(store): ipv4_stix_object = stix2.parse(ipv4_dict) store.add(ipv4_stix_object) read_obj = store.get(ipv4_stix_object['id']) @@ -318,7 +348,7 @@ def test_ipv4(): assert ipv4_dict[attrib] == read_obj[attrib] -def test_ipv6(): +def test_ipv6(store): ipv6_stix_object = stix2.parse(ipv6_dict) store.add(ipv6_stix_object) read_obj = store.get(ipv6_stix_object['id']) @@ -336,7 +366,7 @@ def test_ipv6(): } -def test_mutex(): +def test_mutex(store): mutex_stix_object = stix2.parse(mutex_dict) store.add(mutex_stix_object) read_obj = store.get(mutex_stix_object['id']) @@ -376,7 +406,7 @@ def test_mutex(): } -def test_network_traffic(): +def test_network_traffic(store): network_traffic_stix_object = stix2.parse(network_traffic_dict) store.add(network_traffic_stix_object) read_obj = store.get(network_traffic_stix_object['id']) @@ -414,7 +444,7 @@ def test_network_traffic(): } -def test_process(): +def test_process(store): process_stix_object = stix2.parse(process_dict) store.add(process_stix_object) read_obj = json.loads(store.get(process_stix_object['id']).serialize()) @@ -438,7 +468,7 @@ def test_process(): } -def test_software(): +def test_software(store): software_stix_object = stix2.parse(software_dict) store.add(software_stix_object) read_obj = json.loads(store.get(software_stix_object['id']).serialize()) @@ -455,7 +485,7 @@ def test_software(): } -def test_url(): +def test_url(store): url_stix_object = stix2.parse(url_dict) store.add(url_stix_object) read_obj = json.loads(store.get(url_stix_object['id']).serialize()) @@ -486,7 +516,7 @@ def test_url(): } -def test_user_account(): +def test_user_account(store): user_account_stix_object = stix2.parse(user_account_dict) store.add(user_account_stix_object) read_obj = json.loads(store.get(user_account_stix_object['id']).serialize()) @@ -526,7 +556,7 @@ def test_user_account(): } -def test_windows_registry(): +def test_windows_registry(store): windows_registry_stix_object = stix2.parse(windows_registry_dict) store.add(windows_registry_stix_object) read_obj = json.loads(store.get(windows_registry_stix_object['id']).serialize()) @@ -584,7 +614,7 @@ def test_windows_registry(): } -def test_basic_x509_certificate(): +def test_basic_x509_certificate(store): basic_x509_certificate_stix_object = stix2.parse(basic_x509_certificate_dict) store.add(basic_x509_certificate_stix_object) read_obj = json.loads(store.get(basic_x509_certificate_stix_object['id']).serialize()) @@ -598,7 +628,7 @@ def test_basic_x509_certificate(): assert basic_x509_certificate_dict[attrib] == read_obj[attrib] -def test_x509_certificate_with_extensions(): +def test_x509_certificate_with_extensions(store): extensions_x509_certificate_stix_object = stix2.parse(extensions_x509_certificate_dict) store.add(extensions_x509_certificate_stix_object) read_obj = json.loads(store.get(extensions_x509_certificate_stix_object['id']).serialize()) @@ -612,12 +642,12 @@ def test_x509_certificate_with_extensions(): assert extensions_x509_certificate_dict[attrib] == read_obj[attrib] -def test_source_get_not_exists(): +def test_source_get_not_exists(store): obj = store.get("identity--00000000-0000-0000-0000-000000000000") assert obj is None -def test_source_no_registration(): +def test_source_no_registration(store): with pytest.raises(DataSourceError): # error, since no registered class can be found store.get("doesnt-exist--a9e52398-3312-4377-90c2-86d49446c0d0") @@ -875,13 +905,13 @@ class TestClass: _unregister(reg_section, TestClass._type, ext_id) -def test_property(object_variation): +def test_property(db_backend, object_variation): """ Try to more exhaustively test many different property configurations: ensure schemas can be created and values can be stored and retrieved. """ rdb_store = RelationalDBStore( - PostgresBackend(_DB_CONNECT_URL, True), + db_backend, True, None, True, @@ -894,8 +924,10 @@ def test_property(object_variation): assert read_obj == object_variation + rdb_store.metadata.drop_all(db_backend.database_connection) -def test_dictionary_property_complex(): + +def test_dictionary_property_complex(db_backend): """ Test a dictionary property with multiple valid_types """ @@ -921,7 +953,7 @@ def test_dictionary_property_complex(): ) rdb_store = RelationalDBStore( - PostgresBackend(_DB_CONNECT_URL, True), + db_backend, True, None, True, @@ -933,8 +965,10 @@ def test_dictionary_property_complex(): read_obj = rdb_store.get(obj["id"]) assert read_obj == obj + rdb_store.metadata.drop_all(db_backend.database_connection) + -def test_extension_definition(): +def test_extension_definition(store): obj = stix2.ExtensionDefinition( created_by_ref="identity--8a5fb7e4-aabe-4635-8972-cbcde1fa4792", labels=["label1", "label2"],