Skip to content

Commit 0bcda7a

Browse files
committed
add better implementation for sqlite sequence
1 parent ddf36d4 commit 0bcda7a

File tree

3 files changed

+41
-8
lines changed

3 files changed

+41
-8
lines changed

stix2/datastore/relational_db/database_backends/database_backend_base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Any
22

33
from sqlalchemy import (
4-
Boolean, CheckConstraint, Float, Integer, String, Text, create_engine,
4+
Boolean, CheckConstraint, Float, Integer, Sequence, String, Text, create_engine,
55
)
66
from sqlalchemy_utils import create_database, database_exists, drop_database
77

@@ -147,3 +147,6 @@ def process_value_for_insert(self, stix_type, value):
147147
def next_id(self, data_sink):
148148
with self.database_connection.begin() as trans:
149149
return trans.execute(data_sink.sequence)
150+
151+
def create_sequence(self, metadata):
152+
return Sequence("my_general_seq", metadata=metadata, start=1, schema=self.schema_for_core())

stix2/datastore/relational_db/database_backends/sqlite_backend.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Any
22

3-
from sqlalchemy import Text, event
3+
from sqlalchemy import Table, Column, Text, event, insert, select, update
4+
from sqlalchemy.schema import CreateTable
45

56
from .database_backend_base import DatabaseBackend
67

@@ -10,6 +11,8 @@ class SQLiteBackend(DatabaseBackend):
1011

1112
temp_sequence_count = 0
1213

14+
select_stmt_for_sequence = None
15+
1316
def __init__(self, database_connection_url=default_database_connection_url, force_recreate=False, **kwargs: Any):
1417
super().__init__(database_connection_url, force_recreate=force_recreate, **kwargs)
1518

@@ -60,8 +63,35 @@ def array_allowed():
6063
def create_regex_constraint_clause(self, column_name, pattern):
6164
return f"{column_name} REGEXP {pattern}"
6265

63-
@staticmethod
64-
def next_id(data_sink):
66+
def next_id(self, data_sink):
6567
# hack, which is not reliable, must look for a better solution
66-
SQLiteBackend.temp_sequence_count += 1
67-
return SQLiteBackend.temp_sequence_count
68+
# SQLiteBackend.temp_sequence_count += 1
69+
# return SQLiteBackend.temp_sequence_count
70+
# ALWAYS CALL WITHIN A TRANSACTION?????
71+
72+
# better solution, but probably not best
73+
value = 0
74+
conn = self.database_connection.connect()
75+
for row in conn.execute(self.select_stmt_for_sequence):
76+
value = row[0]
77+
if value == 0:
78+
stmt = insert(data_sink.sequence).values({"id": 1, "value": 0})
79+
conn = self.database_connection.connect()
80+
conn.execute(stmt)
81+
conn.commit()
82+
value += 1
83+
conn.execute(update(data_sink.sequence).where(data_sink.sequence.c.id == 1).values({"value": value}))
84+
conn.commit()
85+
return value
86+
87+
def create_sequence(self, metadata):
88+
# need id column, so update has something to work with (see above)
89+
t = Table("my_general_seq",
90+
metadata,
91+
Column("id", SQLiteBackend.determine_sql_type_for_key_as_int()),
92+
Column("value", SQLiteBackend.determine_sql_type_for_integer_property()),
93+
schema=self.schema_for_core())
94+
CreateTable(t).compile(self.database_connection)
95+
self.select_stmt_for_sequence = select(t.c.value)
96+
return t
97+

stix2/datastore/relational_db/relational_db.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from sqlalchemy import MetaData, delete
2-
from sqlalchemy.schema import CreateTable, Sequence
2+
from sqlalchemy.schema import CreateTable
33

44
from stix2.base import _STIXBase
55
from stix2.datastore import DataSink, DataSource, DataStoreMixin
@@ -138,7 +138,6 @@ def __init__(
138138
create_table_objects(
139139
self.metadata, stix_object_classes,
140140
)
141-
self.sequence = Sequence("my_general_seq", metadata=self.metadata, start=1, schema=db_backend.schema_for_core())
142141

143142
self.allow_custom = allow_custom
144143

@@ -155,6 +154,7 @@ def __init__(
155154
self._instantiate_database(print_sql)
156155

157156
def _instantiate_database(self, print_sql=False):
157+
self.sequence = self.db_backend.create_sequence(self.metadata)
158158
self.metadata.create_all(self.db_backend.database_connection)
159159
if print_sql:
160160
for t in self.metadata.tables.values():

0 commit comments

Comments
 (0)