Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

"""
from alembic import op
from libs.uuid_utils import uuidv7
import models as models
import sqlalchemy as sa
from sqlalchemy.sql import table, column
import uuid

# revision identifiers, used by Alembic.
revision = 'e8446f481c1e'
Expand All @@ -21,7 +21,7 @@
def upgrade():
# Create provider_credentials table
op.create_table('provider_credentials',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('provider_name', sa.String(length=255), nullable=False),
sa.Column('credential_name', sa.String(length=255), nullable=False),
Expand Down Expand Up @@ -63,7 +63,7 @@ def migrate_existing_providers_data():
column('updated_at', sa.DateTime()),
column('credential_id', models.types.StringUUID()),
)

provider_credential_table = table('provider_credentials',
column('id', models.types.StringUUID()),
column('tenant_id', models.types.StringUUID()),
Expand All @@ -79,15 +79,15 @@ def migrate_existing_providers_data():

# Query all existing providers data
existing_providers = conn.execute(
sa.select(providers_table.c.id, providers_table.c.tenant_id,
sa.select(providers_table.c.id, providers_table.c.tenant_id,
providers_table.c.provider_name, providers_table.c.encrypted_config,
providers_table.c.created_at, providers_table.c.updated_at)
.where(providers_table.c.encrypted_config.isnot(None))
).fetchall()

# Iterate through each provider and insert into provider_credentials
for provider in existing_providers:
credential_id = str(uuid.uuid4())
credential_id = str(uuidv7())
if not provider.encrypted_config or provider.encrypted_config.strip() == '':
continue

Expand Down Expand Up @@ -134,7 +134,7 @@ def downgrade():

def migrate_data_back_to_providers():
"""Migrate data back from provider_credentials to providers table for downgrade"""

# Define table structure for data manipulation
providers_table = table('providers',
column('id', models.types.StringUUID()),
Expand All @@ -143,7 +143,7 @@ def migrate_data_back_to_providers():
column('encrypted_config', sa.Text()),
column('credential_id', models.types.StringUUID()),
)

provider_credential_table = table('provider_credentials',
column('id', models.types.StringUUID()),
column('tenant_id', models.types.StringUUID()),
Expand All @@ -160,18 +160,18 @@ def migrate_data_back_to_providers():
sa.select(providers_table.c.id, providers_table.c.credential_id)
.where(providers_table.c.credential_id.isnot(None))
).fetchall()

# For each provider, get the credential data and update providers table
for provider in providers_with_credentials:
credential = conn.execute(
sa.select(provider_credential_table.c.encrypted_config)
.where(provider_credential_table.c.id == provider.credential_id)
).fetchone()

if credential:
# Update providers table with encrypted_config from credential
conn.execute(
providers_table.update()
.where(providers_table.c.id == provider.id)
.values(encrypted_config=credential.encrypted_config)
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
Create Date: 2025-08-13 16:05:42.657730

"""
import uuid

from alembic import op
from libs.uuid_utils import uuidv7
import models as models
import sqlalchemy as sa
from sqlalchemy.sql import table, column
Expand All @@ -23,7 +23,7 @@
def upgrade():
# Create provider_model_credentials table
op.create_table('provider_model_credentials',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('provider_name', sa.String(length=255), nullable=False),
sa.Column('model_name', sa.String(length=255), nullable=False),
Expand Down Expand Up @@ -71,7 +71,7 @@ def migrate_existing_provider_models_data():
column('updated_at', sa.DateTime()),
column('credential_id', models.types.StringUUID()),
)

provider_model_credentials_table = table('provider_model_credentials',
column('id', models.types.StringUUID()),
column('tenant_id', models.types.StringUUID()),
Expand All @@ -90,19 +90,19 @@ def migrate_existing_provider_models_data():

# Query all existing provider_models data with encrypted_config
existing_provider_models = conn.execute(
sa.select(provider_models_table.c.id, provider_models_table.c.tenant_id,
sa.select(provider_models_table.c.id, provider_models_table.c.tenant_id,
provider_models_table.c.provider_name, provider_models_table.c.model_name,
provider_models_table.c.model_type, provider_models_table.c.encrypted_config,
provider_models_table.c.created_at, provider_models_table.c.updated_at)
.where(provider_models_table.c.encrypted_config.isnot(None))
).fetchall()

# Iterate through each provider_model and insert into provider_model_credentials
for provider_model in existing_provider_models:
if not provider_model.encrypted_config or provider_model.encrypted_config.strip() == '':
continue

credential_id = str(uuid.uuid4())
credential_id = str(uuidv7())

# Insert into provider_model_credentials table
conn.execute(
Expand Down Expand Up @@ -148,14 +148,14 @@ def downgrade():

def migrate_data_back_to_provider_models():
"""Migrate data back from provider_model_credentials to provider_models table for downgrade"""

# Define table structure for data manipulation
provider_models_table = table('provider_models',
column('id', models.types.StringUUID()),
column('encrypted_config', sa.Text()),
column('credential_id', models.types.StringUUID()),
)

provider_model_credentials_table = table('provider_model_credentials',
column('id', models.types.StringUUID()),
column('encrypted_config', sa.Text()),
Expand All @@ -169,14 +169,14 @@ def migrate_data_back_to_provider_models():
sa.select(provider_models_table.c.id, provider_models_table.c.credential_id)
.where(provider_models_table.c.credential_id.isnot(None))
).fetchall()

# For each provider_model, get the credential data and update provider_models table
for provider_model in provider_models_with_credentials:
credential = conn.execute(
sa.select(provider_model_credentials_table.c.encrypted_config)
.where(provider_model_credentials_table.c.id == provider_model.credential_id)
).fetchone()

if credential:
# Update provider_models table with encrypted_config from credential
conn.execute(
Expand Down
4 changes: 2 additions & 2 deletions api/models/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ class ProviderCredential(Base):
sa.Index("provider_credential_tenant_provider_idx", "tenant_id", "provider_name"),
)

id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuidv7()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
credential_name: Mapped[str] = mapped_column(String(255), nullable=False)
Expand All @@ -300,7 +300,7 @@ class ProviderModelCredential(Base):
),
)

id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()"))
id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuidv7()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
provider_name: Mapped[str] = mapped_column(String(255), nullable=False)
model_name: Mapped[str] = mapped_column(String(255), nullable=False)
Expand Down
60 changes: 58 additions & 2 deletions api/tests/test_containers_integration_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
import logging
import os
from collections.abc import Generator
from pathlib import Path
from typing import Optional

import pytest
from flask import Flask
from flask.testing import FlaskClient
from sqlalchemy import Engine, text
from sqlalchemy.orm import Session
from testcontainers.core.container import DockerContainer
from testcontainers.core.waiting_utils import wait_for_logs
Expand Down Expand Up @@ -64,7 +66,7 @@ def start_containers_with_env(self) -> None:
# PostgreSQL is used for storing user data, workflows, and application state
logger.info("Initializing PostgreSQL container...")
self.postgres = PostgresContainer(
image="postgres:16-alpine",
image="postgres:14-alpine",
)
self.postgres.start()
db_host = self.postgres.get_container_host_ip()
Expand Down Expand Up @@ -116,7 +118,7 @@ def start_containers_with_env(self) -> None:
# Start Redis container for caching and session management
# Redis is used for storing session data, cache entries, and temporary data
logger.info("Initializing Redis container...")
self.redis = RedisContainer(image="redis:latest", port=6379)
self.redis = RedisContainer(image="redis:6-alpine", port=6379)
self.redis.start()
redis_host = self.redis.get_container_host_ip()
redis_port = self.redis.get_exposed_port(6379)
Expand Down Expand Up @@ -184,6 +186,57 @@ def stop_containers(self) -> None:
_container_manager = DifyTestContainers()


def _get_migration_dir() -> Path:
conftest_dir = Path(__file__).parent
return conftest_dir.parent.parent / "migrations"


def _get_engine_url(engine: Engine):
try:
return engine.url.render_as_string(hide_password=False).replace("%", "%%")
except AttributeError:
return str(engine.url).replace("%", "%%")


_UUIDv7SQL = r"""
/* Main function to generate a uuidv7 value with millisecond precision */
CREATE FUNCTION uuidv7() RETURNS uuid
AS
$$
-- Replace the first 48 bits of a uuidv4 with the current
-- number of milliseconds since 1970-01-01 UTC
-- and set the "ver" field to 7 by setting additional bits
SELECT encode(
set_bit(
set_bit(
overlay(uuid_send(gen_random_uuid()) placing
substring(int8send((extract(epoch from clock_timestamp()) * 1000)::bigint) from
3)
from 1 for 6),
52, 1),
53, 1), 'hex')::uuid;
$$ LANGUAGE SQL VOLATILE PARALLEL SAFE;

COMMENT ON FUNCTION uuidv7 IS
'Generate a uuid-v7 value with a 48-bit timestamp (millisecond precision) and 74 bits of randomness';

CREATE FUNCTION uuidv7_boundary(timestamptz) RETURNS uuid
AS
$$
/* uuid fields: version=0b0111, variant=0b10 */
SELECT encode(
overlay('\x00000000000070008000000000000000'::bytea
placing substring(int8send(floor(extract(epoch from $1) * 1000)::bigint) from 3)
from 1 for 6),
'hex')::uuid;
$$ LANGUAGE SQL STABLE STRICT PARALLEL SAFE;

COMMENT ON FUNCTION uuidv7_boundary(timestamptz) IS
'Generate a non-random uuidv7 with the given timestamp (first 48 bits) and all random bits to 0.
As the smallest possible uuidv7 for that timestamp, it may be used as a boundary for partitions.';
"""


def _create_app_with_containers() -> Flask:
"""
Create Flask application configured to use test containers.
Expand Down Expand Up @@ -211,7 +264,10 @@ def _create_app_with_containers() -> Flask:

# Initialize database schema
logger.info("Creating database schema...")

with app.app_context():
with db.engine.connect() as conn, conn.begin():
conn.execute(text(_UUIDv7SQL))
db.create_all()
logger.info("Database schema created successfully")

Expand Down