@@ -40,7 +40,7 @@
 from currencies.models import Currency
 from core.helpers import ensure_serializable
 from dateutil.parser import parse
-from sales_channels.integrations.amazon.factories.sales_channels.issues import FetchRemoteIssuesFactory
+from sales_channels.integrations.amazon.factories.sales_channels.recently_synced_products import FetchRecentlySyncedProductFactory
 import datetime
 from imports_exports.helpers import append_broken_record, increment_processed_records
 from sales_channels.integrations.amazon.models.imports import (
@@ -980,10 +980,11 @@ def process_product_item(self, product):
         self.handle_gtin_exemption(instance, view)
         self.handle_product_browse_node(instance, view)
 
-        FetchRemoteIssuesFactory(
+        FetchRecentlySyncedProductFactory(
             remote_product=instance.remote_instance,
             view=view,
-            response_data=product
+            response_data=product,
+            match_images=False,
         ).run()
 
         product_obj = product_instance or instance.instance
@@ -50,7 +50,6 @@ def update_assign_issues(
         if self.remote_product.id != self.remote_instance.id:
             self.remote_product = self.remote_instance
 
-
         FetchRemoteValidationIssueFactory(
             remote_product=self.remote_product,
             view=self.view,
@@ -505,7 +504,6 @@ def update_product(
         listings = ListingsApi(self._get_client())
         response = listings.patch_listings_item(**self._build_listing_kwargs(sku, marketplace_id, body, force_validation_only))
 
-
         if getattr(self, "remote_product", None):
             self.remote_product.last_sync_at = timezone.now()
             self.remote_product.save(update_fields=["last_sync_at"])
OneSila/sales_channels/integrations/amazon/factories/sales_channels/issues.py
@@ -1,82 +1,7 @@
 from core.helpers import ensure_serializable
 from sales_channels.integrations.amazon.factories.mixins import GetAmazonAPIMixin
 from sales_channels.integrations.amazon.models import AmazonProductIssue
 
 
-class FetchRemoteIssuesFactory(GetAmazonAPIMixin):
-    """Fetch latest listing issues for a remote product in a marketplace."""
-
-    def __init__(self, *, remote_product, view, response_data=None):
-        self.remote_product = remote_product
-        self.view = view
-        self.sales_channel = view.sales_channel.get_real_instance()
-        self.response_data = response_data
-
-    def run(self):
-
-        if not self.remote_product or not getattr(self.remote_product, 'remote_sku', None):
-            return
-
-        AmazonProductIssue.objects.filter(
-            remote_product=self.remote_product,
-            view=self.view,
-            is_validation_issue=False,
-        ).delete()
-
-        if self.response_data:
-            response = self.response_data
-        else:
-            response = self.get_listing_item(
-                self.remote_product.remote_sku,
-                self.view.remote_id,
-                included_data=["issues", "summaries"],
-            )
-
-        if isinstance(response, dict):
-            issues_data = response.get("issues", []) or []
-            summaries = response.get("summaries", []) or []
-        else:
-            issues_data = getattr(response, "issues", []) or []
-            summaries = getattr(response, "summaries", []) or []
-
-        if summaries:
-            summary = summaries[0]
-            asin = summary.get("asin") if isinstance(summary, dict) else getattr(summary, "asin", None)
-            if asin and getattr(self.remote_product, "local_instance", None):
-                from sales_channels.integrations.amazon.models import AmazonExternalProductId
-
-                product = self.remote_product.local_instance
-                try:
-                    ext = AmazonExternalProductId.objects.get(product=product, view=self.view)
-                    if ext.created_asin != asin:
-                        ext.created_asin = asin
-                        ext.save(update_fields=["created_asin"])
-                except AmazonExternalProductId.DoesNotExist:
-                    AmazonExternalProductId.objects.create(
-                        multi_tenant_company=self.remote_product.multi_tenant_company,
-                        product=product,
-                        view=self.view,
-                        type=AmazonExternalProductId.TYPE_ASIN,
-                        value=asin,
-                        created_asin=asin,
-                    )
-
-        for issue in issues_data:
-            data = ensure_serializable(
-                issue.to_dict() if hasattr(issue, "to_dict") else issue
-            )
-            AmazonProductIssue.objects.create(
-                multi_tenant_company=self.view.multi_tenant_company,
-                remote_product=self.remote_product,
-                view=self.view,
-                code=data.get("code"),
-                message=data.get("message"),
-                severity=data.get("severity"),
-                raw_data=data,
-                is_validation_issue=False,
-            )
-
-
 class FetchRemoteValidationIssueFactory:
     """Persist validation issues returned from API submissions."""
 
OneSila/sales_channels/integrations/amazon/factories/sales_channels/recently_synced_products.py
@@ -0,0 +1,147 @@
+from core.helpers import ensure_serializable
+from sales_channels.integrations.amazon.factories.mixins import GetAmazonAPIMixin
+from sales_channels.integrations.amazon.models import AmazonProductIssue
+
+
+class FetchRecentlySyncedProductFactory(GetAmazonAPIMixin):
+    """Fetch issues, ASIN, and image mappings for a remote product."""
+
+    def __init__(self, *, remote_product, view, response_data=None, match_images=False):
+        self.remote_product = remote_product
+        self.view = view
+        self.sales_channel = view.sales_channel.get_real_instance()
+        self.response_data = response_data
+        self.match_images = match_images
+
+    def run(self):
+        if not self._is_valid_product():
+            return
+
+        self._clear_existing_issues()
+        response = self._get_response()
+        issues, summaries, attributes = self._extract_sections(response)
+        self._sync_asin(summaries)
+        self._persist_issues(issues)
+        if self.match_images:
+            self._match_images(attributes)
+
+    def _is_valid_product(self):
+        return self.remote_product and getattr(self.remote_product, "remote_sku", None)
+
+    def _clear_existing_issues(self):
+        AmazonProductIssue.objects.filter(
+            remote_product=self.remote_product,
+            view=self.view,
+            is_validation_issue=False,
+        ).delete()
+
+    def _get_response(self):
+        if self.response_data:
+            return self.response_data
+        return self.get_listing_item(
+            self.remote_product.remote_sku,
+            self.view.remote_id,
+            included_data=["issues", "summaries", "attributes"],
+        )
+
+    def _extract_sections(self, response):
+        if isinstance(response, dict):
+            issues_data = response.get("issues", []) or []
+            summaries = response.get("summaries", []) or []
+            attributes = response.get("attributes", {}) or {}
+        else:
+            issues_data = getattr(response, "issues", []) or []
+            summaries = getattr(response, "summaries", []) or []
+            attributes = getattr(response, "attributes", {}) or {}
+        return issues_data, summaries, attributes
+
+    def _sync_asin(self, summaries):
+        if not summaries:
+            return
+        summary = summaries[0]
+        asin = summary.get("asin") if isinstance(summary, dict) else getattr(summary, "asin", None)
+        if asin and getattr(self.remote_product, "local_instance", None):
+            from sales_channels.integrations.amazon.models import AmazonExternalProductId
+
+            product = self.remote_product.local_instance
+            try:
+                ext = AmazonExternalProductId.objects.get(product=product, view=self.view)
+                if ext.created_asin != asin:
+                    ext.created_asin = asin
+                    ext.save(update_fields=["created_asin"])
+            except AmazonExternalProductId.DoesNotExist:
+                AmazonExternalProductId.objects.create(
+                    multi_tenant_company=self.remote_product.multi_tenant_company,
+                    product=product,
+                    view=self.view,
+                    type=AmazonExternalProductId.TYPE_ASIN,
+                    value=asin,
+                    created_asin=asin,
+                )
+
+    def _persist_issues(self, issues_data):
+        for issue in issues_data:
+            data = ensure_serializable(
+                issue.to_dict() if hasattr(issue, "to_dict") else issue
+            )
+            AmazonProductIssue.objects.create(
+                multi_tenant_company=self.view.multi_tenant_company,
+                remote_product=self.remote_product,
+                view=self.view,
+                code=data.get("code"),
+                message=data.get("message"),
+                severity=data.get("severity"),
+                raw_data=data,
+                is_validation_issue=False,
+            )
+
+    def _match_images(self, attributes):
+        if not self.remote_product.product_owner or not getattr(self.remote_product, "local_instance", None):
+            return
+
+        from media.models import Media, MediaProductThrough
+        from sales_channels.integrations.amazon.models import AmazonImageProductAssociation
+        from sales_channels.integrations.amazon.image_similarity import phash_is_same
+
+        image_keys = [
+            "main_product_image_locator",
+            *[f"other_product_image_locator_{i}" for i in range(1, 9)],
+        ]
+        remote_urls = []
+        for key in image_keys:
+            val = attributes.get(key) if isinstance(attributes, dict) else getattr(attributes, key, None)
+            if not val:
+                continue
+            item = val[0] if isinstance(val, list) else val
+            url = item.get("media_location") if isinstance(item, dict) else getattr(item, "media_location", None)
+            if url:
+                remote_urls.append(url)
+
+        if not remote_urls:
+            return
+
+        throughs = (
+            MediaProductThrough.objects.filter(
+                product=self.remote_product.local_instance, media__type=Media.IMAGE
+            ).order_by("sort_order")
+        )
+
+        for through, remote_url in zip(throughs, remote_urls):
+            local_path = getattr(through.media.image, "path", None)
+            if not local_path:
+                continue
+            try:
+                is_same = phash_is_same(local_path, remote_url, threshold=95.0)
+            except Exception:
+                continue
+            if not is_same:
+                continue
+            instance, _ = AmazonImageProductAssociation.objects.get_or_create(
+                multi_tenant_company=self.view.multi_tenant_company,
+                sales_channel=self.sales_channel,
+                local_instance=through,
+                remote_product=self.remote_product,
+            )
+            if instance.imported_url != remote_url:
+                instance.imported_url = remote_url
+                instance.save(update_fields=["imported_url"])
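
note: zip() pairs local images to remote locators purely by position and stops at the shorter sequence, so the matching relies on sort_order lining up with Amazon's main/other locator order. A minimal sketch with placeholder values (not data from this PR):

    # Positional pairing, as in _match_images: MediaProductThrough rows ordered
    # by sort_order against Amazon's image locators in key order.
    # zip() truncates to the shorter list, so a third local image with only two
    # remote locators is never compared at all.
    local_images = ["main.jpg", "alt-1.jpg", "alt-2.jpg"]
    remote_urls = ["https://img.example/A.jpg", "https://img.example/B.jpg"]
    pairs = list(zip(local_images, remote_urls))
    # -> [("main.jpg", ".../A.jpg"), ("alt-1.jpg", ".../B.jpg")]
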
OneSila/sales_channels/integrations/amazon/flows/recently_synced_products.py
@@ -0,0 +1,22 @@
+from datetime import timedelta
+from django.utils import timezone
+from sales_channels.integrations.amazon.models import AmazonProduct, AmazonSalesChannelView
+from sales_channels.integrations.amazon.factories.sales_channels.recently_synced_products import (
+    FetchRecentlySyncedProductFactory,
+)
+
+
+def refresh_recent_amazon_products_flow():
+    """Refresh data for Amazon products synced in the last 15 minutes."""
+    cutoff = timezone.now() - timedelta(minutes=15)
+    products = AmazonProduct.objects.filter(last_sync_at__gte=cutoff)
+    for product in products.iterator():
+        marketplaces = product.created_marketplaces or []
+        views = AmazonSalesChannelView.objects.filter(
+            sales_channel=product.sales_channel, remote_id__in=marketplaces
+        )
+        for view in views:
+            fac = FetchRecentlySyncedProductFactory(
+                remote_product=product, view=view, match_images=True
+            )
+            fac.run()
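
note: this flow is what the */15 cronjob in tasks.py (below) calls; for one-off testing it can also be invoked by hand. A minimal sketch, assuming a Django shell with the project loaded:

    # Run one refresh cycle outside the cron schedule (e.g. python manage.py shell).
    # Products whose last_sync_at is older than 15 minutes fall outside the cutoff.
    from sales_channels.integrations.amazon.flows.recently_synced_products import (
        refresh_recent_amazon_products_flow,
    )

    refresh_recent_amazon_products_flow()
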
46 changes: 46 additions & 0 deletions OneSila/sales_channels/integrations/amazon/image_similarity.py
@@ -0,0 +1,46 @@
+import io
+import math
+from typing import Tuple
+import requests
+from PIL import Image
+import imagehash
+
+DEFAULT_HEADERS = {"User-Agent": "img-similarity/1.0"}
+MAX_BYTES = 25 * 1024 * 1024  # 25 MB safety cap
+
+
+def _fetch_bytes(url: str, timeout: Tuple[float, float] = (5, 20)) -> bytes:
+    """Download URL into memory with a size cap."""
+    with requests.get(url, headers=DEFAULT_HEADERS, stream=True, timeout=timeout) as r:
+        r.raise_for_status()
+        total = 0
+        chunks = []
+        for chunk in r.iter_content(1024 * 32):
+            if not chunk:
+                break
+            total += len(chunk)
Comment on lines +12 to +21
suggestion: No retry logic for transient network errors when fetching images.

Adding retry logic will help prevent missed matches due to temporary network issues. You can implement a retry loop or use a library with built-in retry support to improve reliability.

Suggested implementation:

import math
from typing import Tuple
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from PIL import Image
import imagehash
def _fetch_bytes(url: str, timeout: Tuple[float, float] = (5, 20)) -> bytes:
    """Download URL into memory with a size cap and retry logic for transient errors."""
    session = requests.Session()
    retries = Retry(
        total=3,
        backoff_factor=0.5,
        status_forcelist=[500, 502, 503, 504],
        allowed_methods=["GET"],
        raise_on_status=False,
    )
    adapter = HTTPAdapter(max_retries=retries)
    session.mount("http://", adapter)
    session.mount("https://", adapter)

    with session.get(url, headers=DEFAULT_HEADERS, stream=True, timeout=timeout) as r:
        r.raise_for_status()
        total = 0
        chunks = []
        for chunk in r.iter_content(1024 * 32):
            if not chunk:
                break
            total += len(chunk)
            if total > MAX_BYTES:
                raise ValueError(f"Image too large (> {MAX_BYTES} bytes): {url}")
            chunks.append(chunk)
        return b"".join(chunks)

+            if total > MAX_BYTES:
+                raise ValueError(f"Image too large (> {MAX_BYTES} bytes): {url}")
+            chunks.append(chunk)
+        return b"".join(chunks)
+
+
+def _pil_from_source(src: str) -> Image.Image:
+    if src.startswith("http://") or src.startswith("https://"):
+        data = _fetch_bytes(src)
+        img = Image.open(io.BytesIO(data))
+    else:
+        img = Image.open(src)
+    img.load()
+    return img
Comment on lines +28 to +35
suggestion (bug_risk): No validation of image format or content after loading.

Consider adding checks to ensure the image is valid and in a supported format to prevent exceptions from corrupt or unsupported files.

Suggested change
-def _pil_from_source(src: str) -> Image.Image:
-    if src.startswith("http://") or src.startswith("https://"):
-        data = _fetch_bytes(src)
-        img = Image.open(io.BytesIO(data))
-    else:
-        img = Image.open(src)
-    img.load()
-    return img
+def _pil_from_source(src: str) -> Image.Image:
+    SUPPORTED_FORMATS = {"JPEG", "PNG", "BMP", "GIF", "WEBP"}
+    try:
+        if src.startswith("http://") or src.startswith("https://"):
+            data = _fetch_bytes(src)
+            img = Image.open(io.BytesIO(data))
+        else:
+            img = Image.open(src)
+        img.load()
+    except Exception as e:
+        raise ValueError(f"Failed to load image from source '{src}': {e}")
+    if not hasattr(img, "format") or img.format is None:
+        raise ValueError(f"Image format could not be determined for source '{src}'.")
+    if img.format.upper() not in SUPPORTED_FORMATS:
+        raise ValueError(f"Unsupported image format '{img.format}' for source '{src}'. Supported formats: {', '.join(SUPPORTED_FORMATS)}.")
+    return img



+def phash_is_same(src1: str, src2: str, hash_size: int = 16, threshold: float = 95.0) -> bool:
+    img1 = _pil_from_source(src1).convert("RGB")
+    img2 = _pil_from_source(src2).convert("RGB")
+    h1 = imagehash.phash(img1, hash_size=hash_size)
+    h2 = imagehash.phash(img2, hash_size=hash_size)
+    dist = int(h1 - h2)
+    max_bits = hash_size * hash_size
suggestion (code-quality): Replace x * x with x ** 2 (square-identity)

Suggested change
-    max_bits = hash_size * hash_size
+    max_bits = hash_size**2

+    allowed_dist = math.floor((1.0 - threshold / 100.0) * max_bits)
+    return dist <= allowed_dist
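
note: a worked example of the threshold math: with hash_size=16 the pHash carries 16 * 16 = 256 bits, so threshold=95.0 tolerates floor(0.05 * 256) = 12 differing bits. A usage sketch with placeholder paths (not values from this PR):

    # Compare a local product image against a remote Amazon image locator.
    same = phash_is_same(
        "/srv/media/products/main.jpg",                     # hypothetical local path
        "https://m.media-amazon.com/images/I/example.jpg",  # hypothetical remote URL
        hash_size=16,    # 16x16 pHash -> 256 bits
        threshold=95.0,  # allow up to 12 differing bits
    )
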
@@ -157,8 +157,8 @@ def refresh_amazon_latest_issues(
             AmazonProduct,
             AmazonSalesChannelView,
         )
-        from sales_channels.integrations.amazon.factories.sales_channels.issues import (
-            FetchRemoteIssuesFactory,
+        from sales_channels.integrations.amazon.factories.sales_channels.recently_synced_products import (
+            FetchRecentlySyncedProductFactory,
         )

         multi_tenant_company = get_multi_tenant_company(info, fail_silently=False)
@@ -173,7 +173,7 @@
             sales_channel__multi_tenant_company=multi_tenant_company,
         )
 
-        factory = FetchRemoteIssuesFactory(
+        factory = FetchRecentlySyncedProductFactory(
             remote_product=remote_product,
             view=view,
         )
24 changes: 6 additions & 18 deletions OneSila/sales_channels/integrations/amazon/tasks.py
@@ -166,24 +166,12 @@ def actual_task():
     task.execute(actual_task)


-@db_periodic_task(crontab(minute='0', hour='0,12'))
-def refresh_amazon_product_issues_cronjob():
-    """Fetch latest listing issues for Amazon products synced in the last 12 hours."""
-    from datetime import timedelta
-    from django.utils import timezone
-    from .models import AmazonProduct, AmazonSalesChannelView
-    from .factories.sales_channels.issues import FetchRemoteIssuesFactory
-
-    cutoff = timezone.now() - timedelta(hours=12)
-    products = AmazonProduct.objects.filter(last_sync_at__gte=cutoff)
-    for product in products.iterator():
-        marketplaces = product.created_marketplaces or []
-        views = AmazonSalesChannelView.objects.filter(
-            sales_channel=product.sales_channel, remote_id__in=marketplaces
-        )
-        for view in views:
-            fac = FetchRemoteIssuesFactory(remote_product=product, view=view)
-            fac.run()
+@db_periodic_task(crontab(minute='*/15'))
+def refresh_recent_amazon_products_cronjob():
+    """Refresh data for Amazon products synced in the last 15 minutes."""
+    from .flows.recently_synced_products import refresh_recent_amazon_products_flow
+
+    refresh_recent_amazon_products_flow()


 @db_periodic_task(crontab(minute='0', hour='0', day='1'))