diff --git a/stix2/datastore/relational_db/demo.py b/stix2/datastore/relational_db/demo.py index dec59dfb..303b50e1 100644 --- a/stix2/datastore/relational_db/demo.py +++ b/stix2/datastore/relational_db/demo.py @@ -1,6 +1,7 @@ import datetime as dt from database_backends.postgres_backend import PostgresBackend +from database_backends.sqlite_backend import SQLiteBackend import sys import json @@ -19,6 +20,7 @@ def main(): bundle = stix2.parse(json.load(f), allow_custom=True) store = RelationalDBStore( PostgresBackend("postgresql://localhost/stix-data-sink", force_recreate=True), + # SQLiteBackend("sqlite:///stix-data-sink.db", force_recreate=True), True, None, True, diff --git a/stix2/datastore/relational_db/input_creation.py b/stix2/datastore/relational_db/input_creation.py index 57ace0ff..dbc2c83b 100644 --- a/stix2/datastore/relational_db/input_creation.py +++ b/stix2/datastore/relational_db/input_creation.py @@ -310,7 +310,7 @@ def generate_insert_information( # noqa: F811 ] for elem in stix_object[name]: bindings = { - "id": stix_object["id"], + "id": foreign_key_value, name: db_backend.process_value_for_insert(self.contained, elem), } insert_statements.append(insert(table).values(bindings)) diff --git a/stix2/datastore/relational_db/relational_db_testing.py b/stix2/datastore/relational_db/relational_db_testing.py index d4fdded7..2a1353b4 100644 --- a/stix2/datastore/relational_db/relational_db_testing.py +++ b/stix2/datastore/relational_db/relational_db_testing.py @@ -290,9 +290,9 @@ def test_dictionary(): def main(): store = RelationalDBStore( - # MariaDBBackend(f"mariadb+pymysql://admin:admin@127.0.0.1:3306/rdb", force_recreate=True), + MariaDBBackend(f"mariadb+pymysql://admin:admin@127.0.0.1:3306/rdb", force_recreate=True), # PostgresBackend("postgresql://localhost/stix-data-sink", force_recreate=True), - SQLiteBackend("sqlite:///stix-data-sink.db", force_recreate=True), + # SQLiteBackend("sqlite:///stix-data-sink.db", force_recreate=True), True, None, diff --git a/stix2/datastore/relational_db/table_creation.py b/stix2/datastore/relational_db/table_creation.py index e80043e4..dc486fad 100644 --- a/stix2/datastore/relational_db/table_creation.py +++ b/stix2/datastore/relational_db/table_creation.py @@ -5,9 +5,8 @@ from stix2.datastore.relational_db.add_method import add_method from stix2.datastore.relational_db.utils import ( - SCO_COMMON_PROPERTIES, SDO_COMMON_PROPERTIES, canonicalize_table_name, - determine_column_name, determine_sql_type_from_stix, flat_classes, - get_stix_object_classes, shorten_extension_definition_id + canonicalize_table_name, determine_column_name, determine_core_properties, determine_sql_type_from_stix, + flat_classes, get_stix_object_classes, shorten_extension_definition_id ) from stix2.properties import ( BinaryProperty, BooleanProperty, DictionaryProperty, @@ -16,7 +15,8 @@ ObjectReferenceProperty, Property, ReferenceProperty, StringProperty, TimestampProperty, TypeProperty, ) -from stix2.v21.base import _Extension, _Observable + +from stix2.v21.base import (_Extension, _Observable) from stix2.v21.common import KillChainPhase @@ -809,14 +809,7 @@ def generate_object_table( table_name = shorten_extension_definition_id(table_name) if parent_table_name: table_name = parent_table_name + "_" + table_name - if is_embedded_object: - core_properties = list() - elif schema_name in ["sdo", "sro", "common"]: - core_properties = SDO_COMMON_PROPERTIES - elif schema_name == "sco": - core_properties = SCO_COMMON_PROPERTIES - else: - core_properties = list() + core_properties = determine_core_properties(stix_object_class, is_embedded_object) columns = list() tables = list() if issubclass(stix_object_class, _Observable): diff --git a/stix2/datastore/relational_db/utils.py b/stix2/datastore/relational_db/utils.py index e49d88dd..e782f5a9 100644 --- a/stix2/datastore/relational_db/utils.py +++ b/stix2/datastore/relational_db/utils.py @@ -41,6 +41,16 @@ } +def determine_core_properties(stix_object_class, is_embedded_object): + if is_embedded_object or issubclass(stix_object_class, (_MetaObject, _Extension)): + return list() + elif issubclass(stix_object_class, (_RelationshipObject, _DomainObject)): + return SDO_COMMON_PROPERTIES + elif issubclass(stix_object_class, _Observable): + return SCO_COMMON_PROPERTIES + else: + raise ValueError(f"{stix_object_class} not a STIX object") + def canonicalize_table_name(table_name, schema_name=None): if schema_name: full_name = schema_name + "." + table_name @@ -103,7 +113,6 @@ def get_stix_object_classes(): ) def schema_for(stix_class): - if issubclass(stix_class, _DomainObject): schema_name = "sdo" elif issubclass(stix_class, _RelationshipObject): @@ -116,7 +125,6 @@ def schema_for(stix_class): schema_name = getattr(stix_class, "_applies_to", "sco") else: schema_name = None - return schema_name