|
18 | 18 | from datetime import timezone
|
19 | 19 | import json
|
20 | 20 | import logging
|
| 21 | +import pickle |
21 | 22 | from typing import Any
|
22 | 23 | from typing import Optional
|
23 | 24 | import uuid
|
24 | 25 |
|
| 26 | +from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import SpannerPickleType |
25 | 27 | from sqlalchemy import Boolean
|
26 | 28 | from sqlalchemy import delete
|
27 | 29 | from sqlalchemy import Dialect
|
@@ -104,6 +106,31 @@ def load_dialect_impl(self, dialect):
|
104 | 106 | return self.impl
|
105 | 107 |
|
106 | 108 |
|
| 109 | +class DynamicPickleType(TypeDecorator): |
| 110 | + """Represents a type that can be pickled.""" |
| 111 | + |
| 112 | + impl = PickleType |
| 113 | + |
| 114 | + def load_dialect_impl(self, dialect): |
| 115 | + if dialect.name == "spanner+spanner": |
| 116 | + return dialect.type_descriptor(SpannerPickleType) |
| 117 | + return self.impl |
| 118 | + |
| 119 | + def process_bind_param(self, value, dialect): |
| 120 | + """Ensures the pickled value is a bytes object before passing it to the database dialect.""" |
| 121 | + if value is not None: |
| 122 | + if dialect.name == "spanner+spanner": |
| 123 | + return pickle.dumps(value) |
| 124 | + return value |
| 125 | + |
| 126 | + def process_result_value(self, value, dialect): |
| 127 | + """Ensures the raw bytes from the database are unpickled back into a Python object.""" |
| 128 | + if value is not None: |
| 129 | + if dialect.name == "spanner+spanner": |
| 130 | + return pickle.loads(value) |
| 131 | + return value |
| 132 | + |
| 133 | + |
107 | 134 | class Base(DeclarativeBase):
|
108 | 135 | """Base class for database tables."""
|
109 | 136 |
|
@@ -209,7 +236,7 @@ class StorageEvent(Base):
|
209 | 236 | PreciseTimestamp, default=func.now()
|
210 | 237 | )
|
211 | 238 | content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
|
212 |
| - actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType) |
| 239 | + actions: Mapped[MutableDict[str, Any]] = mapped_column(DynamicPickleType) |
213 | 240 |
|
214 | 241 | long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
|
215 | 242 | Text, nullable=True
|
|
0 commit comments