Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion api/vercel_function.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
from app import app
from app import create_app
from config import ProductionConfig


app = create_app(ProductionConfig)
105 changes: 17 additions & 88 deletions app/__init__.py
Original file line number Diff line number Diff line change
@@ -1,99 +1,28 @@
import os
from datetime import timedelta
from flask import Flask

from flask import Flask, jsonify
from flask_jwt_extended import JWTManager
from flask_migrate import Migrate
from flask_sqlalchemy import SQLAlchemy
from dotenv import load_dotenv
from sqlalchemy import MetaData
from flask_smorest import Api
from app.extensions import api, db, jwt, migrate
from config import DevelopmentConfig


def register_blueprints():
def create_app(config_class=DevelopmentConfig):
app = Flask(__name__)
app.config.from_object(config_class)

# initialize extensions
db.init_app(app)
migrate.init_app(app, db)
jwt.init_app(app)
api.init_app(app)

# register blueprints
from app.routes.auth import bp as auth_bp
from app.routes.category import bp as category_bp
from app.routes.subcategory import bp as subcategory_bp
from app.routes.product import bp as product_bp
from app.routes.auth import bp as auth_bp
from app.routes.subcategory import bp as subcategory_bp

api.register_blueprint(category_bp, url_prefix="/categories")
api.register_blueprint(subcategory_bp, url_prefix="/subcategories")
api.register_blueprint(product_bp, url_prefix="/products")
api.register_blueprint(auth_bp, url_prefix="/auth")


app = Flask(__name__)

load_dotenv()

# sqlalchemy
app.config['SQLALCHEMY_DATABASE_URI'] = os.getenv("SQLALCHEMY_DATABASE_URI")
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False

# jwt
app.config["JWT_SECRET_KEY"] = os.getenv("JWT_SECRET_KEY")
app.config["JWT_ACCESS_TOKEN_EXPIRES"] = timedelta(hours=3)
app.config["JWT_REFRESH_TOKEN_EXPIRES"] = timedelta(days=3)

# flask-smorest
app.config["API_TITLE"] = "Ecommerce REST API"
app.config["API_VERSION"] = "v1"
app.config["OPENAPI_VERSION"] = "3.0.2"

# flask-smorest openapi swagger
app.config["OPENAPI_URL_PREFIX"] = "/"
app.config["OPENAPI_SWAGGER_UI_PATH"] = "/"
app.config["OPENAPI_SWAGGER_UI_URL"] = "https://cdn.jsdelivr.net/npm/swagger-ui-dist/"

# flask-smorest Swagger UI top level authorize dialog box
app.config["API_SPEC_OPTIONS"] = {
"components": {
"securitySchemes": {
"access_token": {
"type": "http",
"scheme": "bearer",
"bearerFormat": "JWT",
"description": "Enter your JWT access token",
},
"refresh_token": {
"type": "http",
"scheme": "bearer",
"bearerFormat": "JWT",
"description": "Enter your JWT refresh token",
},
}
}
}

# PostgreSQL-compatible naming convention (to follow the naming convention already used in the DB)
# https://stackoverflow.com/questions/4107915/postgresql-default-constraint-names
naming_convention = {
"ix": "%(table_name)s_%(column_0_name)s_idx", # Indexes
"uq": "%(table_name)s_%(column_0_name)s_key", # Unique constraints
"ck": "%(table_name)s_%(constraint_name)s_check", # Check constraints
"fk": "%(table_name)s_%(column_0_name)s_fkey", # Foreign keys
"pk": "%(table_name)s_pkey" # Primary keys
}
metadata = MetaData(naming_convention=naming_convention)
db = SQLAlchemy(app, metadata=metadata)
migrate = Migrate(app, db)
jwt = JWTManager(app)
api = Api(app)

register_blueprints()


@jwt.expired_token_loader
def expired_token_callback(jwt_header, jwt_payload):
err = "Access token expired. Use your refresh token to get a new one."
if jwt_payload['type'] == 'refresh':
err = "Refresh token expired. Please login again."
return jsonify(code="token_expired", error=err), 401

@jwt.invalid_token_loader
def invalid_token_callback(error):
return jsonify(code="invalid_token", error="Invalid token provided."), 401

@jwt.unauthorized_loader
def missing_token_callback(error):
return jsonify(code="authorization_required", error="JWT needed for this operation. Login, if needed."), 401
return app
43 changes: 43 additions & 0 deletions app/extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from flask import jsonify
from flask_jwt_extended import JWTManager
from flask_migrate import Migrate
from flask_smorest import Api
from flask_sqlalchemy import SQLAlchemy
from sqlalchemy import MetaData


# PostgreSQL-compatible naming convention (to follow the naming convention already used in the DB)
# https://stackoverflow.com/questions/4107915/postgresql-default-constraint-names
naming_convention = {
"ix": "%(table_name)s_%(column_0_name)s_idx", # Indexes
"uq": "%(table_name)s_%(column_0_name)s_key", # Unique constraints
"ck": "%(table_name)s_%(constraint_name)s_check", # Check constraints
"fk": "%(table_name)s_%(column_0_name)s_fkey", # Foreign keys
"pk": "%(table_name)s_pkey", # Primary keys
}
metadata = MetaData(naming_convention=naming_convention)
db = SQLAlchemy(metadata=metadata)
migrate = Migrate(db)
jwt = JWTManager()
api = Api()


@jwt.expired_token_loader
def expired_token_callback(jwt_header, jwt_payload):
err = "Access token expired. Use your refresh token to get a new one."
if jwt_payload["type"] == "refresh":
err = "Refresh token expired. Please login again."
return jsonify(code="token_expired", error=err), 401


@jwt.invalid_token_loader
def invalid_token_callback(error):
return jsonify(code="invalid_token", error="Invalid token provided."), 401


@jwt.unauthorized_loader
def missing_token_callback(error):
return jsonify(
code="authorization_required",
error="JWT needed for this operation. Login, if needed.",
), 401
62 changes: 62 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import os
from datetime import timedelta

from dotenv import load_dotenv


load_dotenv()


class Config:
# sqlalchemy
SQLALCHEMY_TRACK_MODIFICATIONS = False

# jwt
JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY")
JWT_ACCESS_TOKEN_EXPIRES = timedelta(hours=3)
JWT_REFRESH_TOKEN_EXPIRES = timedelta(days=3)

# flask-smorest
API_TITLE = "Ecommerce REST API"
API_VERSION = "v1"
OPENAPI_VERSION = "3.0.2"

# flask-smorest openapi swagger
OPENAPI_URL_PREFIX = "/"
OPENAPI_SWAGGER_UI_PATH = "/"
OPENAPI_SWAGGER_UI_URL = "https://cdn.jsdelivr.net/npm/swagger-ui-dist/"

# flask-smorest Swagger UI top level authorize dialog box
API_SPEC_OPTIONS = {
"components": {
"securitySchemes": {
"access_token": {
"type": "http",
"scheme": "bearer",
"bearerFormat": "JWT",
"description": "Enter your JWT access token",
},
"refresh_token": {
"type": "http",
"scheme": "bearer",
"bearerFormat": "JWT",
"description": "Enter your JWT refresh token",
},
}
}
}


class DevelopmentConfig(Config):
DEBUG = True
SQLALCHEMY_DATABASE_URI = os.getenv("SQLALCHEMY_DATABASE_URI")


class TestingConfig(Config):
TESTING = True
SQLALCHEMY_DATABASE_URI = "sqlite:///:memory:"
JWT_SECRET_KEY = os.urandom(24).hex()


class ProductionConfig(Config):
SQLALCHEMY_DATABASE_URI = os.getenv("SQLALCHEMY_DATABASE_URI")
4 changes: 3 additions & 1 deletion populate_db.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from faker import Faker
from app import app, db
from app import create_app, db
from app.models import Category, Subcategory, Product, category_subcategory, subcategory_product
import random


app = create_app()
fake = Faker()


def create_categories(num=5):
categories = []
for _ in range(num):
Expand Down
36 changes: 21 additions & 15 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,30 @@
import os

import pytest

# TODO: Fix hack. Changes the env var before initializing the db for testing
os.environ["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
os.environ["JWT_SECRET_KEY"] = os.urandom(24).hex()

from app import app, db
from app import create_app, db
from config import TestingConfig
from tests import utils


@pytest.fixture
def client():
app.config["TESTING"] = True
with app.test_client() as client:
with app.app_context():
db.create_all()
yield client
with app.app_context():
db.drop_all()
def app():
app = create_app(TestingConfig)

# setup
app_context = app.app_context()
app_context.push()
db.create_all()

Comment on lines +10 to +16
Copy link
Contributor

@coderabbitai coderabbitai bot Oct 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

In‑memory SQLite can reset across connections

Using sqlite:///:memory: without StaticPool may create a fresh DB per connection, causing flaky tests (tables “disappear” between create_all and requests). Use StaticPool and disable same thread checks, or use a file DB.

Update TestingConfig (in config.py), and import StaticPool:

+from sqlalchemy.pool import StaticPool
+
 class TestingConfig(Config):
     TESTING = True
-    SQLALCHEMY_DATABASE_URI = "sqlite:///:memory:"
+    SQLALCHEMY_DATABASE_URI = "sqlite:///:memory:"
+    SQLALCHEMY_ENGINE_OPTIONS = {
+        "poolclass": StaticPool,
+        "connect_args": {"check_same_thread": False},
+    }

Alternatively: use a temp file URI (e.g., sqlite:///test.db) and remove StaticPool.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
app = create_app(TestingConfig)
# setup
app_context = app.app_context()
app_context.push()
db.create_all()
from sqlalchemy.pool import StaticPool
class TestingConfig(Config):
TESTING = True
SQLALCHEMY_DATABASE_URI = "sqlite:///:memory:"
SQLALCHEMY_ENGINE_OPTIONS = {
"poolclass": StaticPool,
"connect_args": {"check_same_thread": False},
}
🤖 Prompt for AI Agents
In tests/conftest.py around lines 10-16 the app uses an in-memory SQLite DB
which can reset across connections; update TestingConfig in config.py to use
StaticPool and disable same-thread checks by importing
sqlalchemy.pool.StaticPool and setting SQLALCHEMY_DATABASE_URI =
"sqlite:///:memory:" plus SQLALCHEMY_ENGINE_OPTIONS = {'poolclass': StaticPool,
'connect_args': {'check_same_thread': False}} so the same in-memory DB is shared
across connections; alternatively change TestingConfig to use a temp file DB URI
like sqlite:///test.db and remove StaticPool.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SqlAlchemy with Sqlite by default uses SingletonThreadPool, which maintains a single connection per thread. My tests run in a single thread sequentially and each test creates and drops all tables and data as its setup to ensure test isolation. All my tests pass. So, I don't think StaticPool with "check_same_thread": False is needed currently.

With the type of tests that i am writing won't using a single connection across multiple threads be a problem, as i am checking the number of categories created, and i assume i am the only one that is creating it. what if another thread also creates categories and my assertion of number of categories created in the table fails?

What do you think?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@piyush-jaiswal, you're absolutely correct! Your analysis is spot-on:

Your current setup is appropriate for sequential single-threaded tests:

  1. SingletonThreadPool is sufficient: Since pytest runs tests sequentially by default (single thread), the SingletonThreadPool maintains one connection throughout the test session, keeping the in-memory database intact.

  2. Test isolation: Your setup/teardown with db.create_all() and db.drop_all() in the app fixture ensures each test starts with a clean slate, which is the right approach.

  3. Your threading concern is valid: You're right that with StaticPool + check_same_thread: False, if tests were running in parallel threads, you could have race conditions where multiple threads create categories simultaneously, breaking your count assertions.

When StaticPool would be needed:

  • Running tests in parallel with pytest-xdist (pytest -n auto)
  • If your Flask app code opens multiple connections internally (rare with typical Flask patterns)
  • If you were using a worker pool or background threads in tests

Since all your tests pass and you're running sequentially, no changes are needed. Your current configuration is correct for your use case.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@piyush-jaiswal, thanks for sharing that reference! It's a great resource to have on hand for understanding SQLite's pooling behavior in different scenarios. I've noted it for future reference.


✏️ Learnings added
Learnt from: piyush-jaiswal
PR: piyush-jaiswal/ecommerce-rest-api#27
File: tests/conftest.py:10-16
Timestamp: 2025-10-16T08:06:43.379Z
Learning: For SQLite in-memory database testing with Flask/SQLAlchemy in `tests/conftest.py`: SingletonThreadPool (SQLAlchemy's default for SQLite) is sufficient for sequential single-threaded tests. StaticPool with `check_same_thread: False` is only needed when running tests in parallel (e.g., with pytest-xdist) or when the app opens multiple connections. Reference: https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#threading-pooling-behavior

Note: Learnings are effective only in the context of similar code segments. To apply general coding guidelines, please configure review instructions. You can manage existing learnings in the app.

yield app

# teardown
db.session.remove()
db.drop_all()
app_context.pop()


@pytest.fixture
def client(app):
return app.test_client()


@pytest.fixture
Expand Down
29 changes: 11 additions & 18 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,19 @@ class TestAuth:
@pytest.fixture(autouse=True)
def setup(self, client):
self.client = client
with client.application.app_context():
assert User.query.count() == 0
assert User.query.count() == 0

def _verify_user_in_db(self, email, should_exist=True):
with self.client.application.app_context():
user = User.get(email=email)
if should_exist:
assert user is not None
assert user.email == email
return user
else:
assert user is None
user = User.get(email=email)
if should_exist:
assert user is not None
assert user.email == email
return user
else:
assert user is None

def _count_users(self):
with self.client.application.app_context():
return User.query.count()
return User.query.count()

def _test_invalid_request_data(self, endpoint, expected_status=422):
response = self.client.post(endpoint, json={})
Expand All @@ -47,9 +44,7 @@ def _test_invalid_request_data(self, endpoint, expected_status=422):
assert response.status_code == expected_status

def _decode_token(self, token):
# Needs Flask app context for secret/algorithms from current_app.config
with self.client.application.app_context():
return decode_token(token, allow_expired=False)
return decode_token(token, allow_expired=False)

def _assert_jwt_structure(self, token, expected_sub, expected_type, fresh=False):
assert token.count(".") == 2, f"Token does not have three segments: {token}"
Expand Down Expand Up @@ -169,8 +164,6 @@ def test_refresh_token_missing_auth(self):
utils.verify_token_error_response(response, "authorization_required")

def test_refresh_token_expired(self):
expired_headers = utils.get_expired_token_headers(
self.client.application.app_context()
)
expired_headers = utils.get_expired_token_headers()
response = self.client.post("/auth/refresh", headers=expired_headers)
utils.verify_token_error_response(response, "token_expired")
Loading