Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions ariadne_codegen/contrib/client_forward_refs.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,15 @@ def _insert_import_statement_in_method(

import_class_name = import_class.name

# Skip calls like self.something(...) from custom operation methods
# (e.g. execute_custom_operation returns self.get_data(response)).
# 'self' is not a generated class to import.
if (
import_class_name == "self"
or import_class_name not in self.imported_classes
):
return

# We add the class to our set of imported in methods - these classes
# don't need to be imported at all in the global scope.
self.imported_in_method.add(import_class_name)
Expand Down
46 changes: 46 additions & 0 deletions tests/contrib/test_client_forward_refs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import ast

from graphql import build_schema

from ariadne_codegen.contrib.client_forward_refs import ClientForwardRefsPlugin


def test_plugin_skips_self_calls_from_custom_operation_methods():
"""Methods that return self.something(...) must not trigger import of 'self'.

With enable_custom_operations, the client has execute_custom_operation()
returning self.get_data(response), and query()/mutation() returning
await self.execute_custom_operation(...). The plugin must not treat
'self' as a generated class name (KeyError('self')).
"""

module = ast.parse(
"""
from .async_base_client import AsyncBaseClient
from .some_operation import SomeOperation

class Client(AsyncBaseClient):
def execute_custom_operation(self):
return self.get_data("response")

async def get_something(self):
from .some_operation import SomeOperation
return SomeOperation.model_validate({})
"""
)

schema = build_schema("type Query { x: Int }")
config = {"target_package_name": "test"}
plugin = ClientForwardRefsPlugin(schema, config)

updated = plugin.generate_client_module(module)

client_class = next(n for n in updated.body if isinstance(n, ast.ClassDef))
methods = {n.name: n for n in client_class.body if isinstance(n, ast.FunctionDef)}
exec_op = methods["execute_custom_operation"]
first_stmt = exec_op.body[0]
assert not (
isinstance(first_stmt, ast.ImportFrom)
and first_stmt.module
and "self" in (a.name for a in first_stmt.names)
), "Plugin must not add import for 'self'"
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[tool.ariadne-codegen]
schema_path = "schema.graphql"
include_comments = "none"
target_package_name = "example_client"
enable_custom_operations = true
plugins = ["ariadne_codegen.contrib.client_forward_refs.ClientForwardRefsPlugin"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
schema {
query: Query
mutation: Mutation
}

type Query {
products(channel: String, first: Int): ProductCountableConnection
app: App
productTypes: ProductTypeCountableConnection
translations(
"""
Return the elements in the list that come before the specified cursor.
"""
before: String

"""
Return the elements in the list that come after the specified cursor.
"""
after: String

"""
Retrieve the first n elements from the list. Note that the system only allows fetching a maximum of 100 objects in a single query.
"""
first: Int

"""
Retrieve the last n elements from the list. Note that the system only allows fetching a maximum of 100 objects in a single query.
"""
last: Int
): TranslatableItemConnection
store: Store!
}

type Mutation {
updateMetadata(
"""
ID or token (for Order and Checkout) of an object to update.
"""
id: ID!
): UpdateMetadata
}

type Product implements ObjectWithMetadata {
id: ID!
slug: String!
name: String!
}

type ProductCountableEdge {
node: Product!
cursor: String!
}

type ProductCountableConnection {
edges: [ProductCountableEdge!]!
pageInfo: PageInfo!
totalCount: Int
}

type App {
id: ID!
}

type ProductTypeCountableConnection {
pageInfo: PageInfo!
}

type PageInfo {
hasNextPage: Boolean!
hasPreviousPage: Boolean!
startCursor: String
endCursor: String
}

interface ObjectWithMetadata {
"""
list of private metadata items. Requires staff permissions to access.
"""
privateMetadata: [MetadataItem!]!

"""
A single key from private metadata. Requires staff permissions to access.

Tip: Use GraphQL aliases to fetch multiple keys.
"""
privateMetafield(key: String!): String

"""
list of public metadata items. Can be accessed without permissions.
"""
metadata: [MetadataItem!]!

"""
A single key from public metadata.

Tip: Use GraphQL aliases to fetch multiple keys.
"""
metafield(key: String!): String
}

type MetadataItem {
"""
Key of a metadata item.
"""
key: String!

"""
Value of a metadata item.
"""
value: String!
}

type UpdateMetadata {
metadataErrors: [MetadataError!]!
@deprecated(
reason: "This field will be removed in Saleor 4.0. Use `errors` field instead."
)
errors: [MetadataError!]!
item: ObjectWithMetadata
}
type MetadataError {
"""
Name of a field that caused the error. A value of `null` indicates that the error isn't associated with a particular field.
"""
field: String

"""
The error message.
"""
message: String

"""
The error code.
"""
code: MetadataErrorCode!
}

"""
An enumeration.
"""
enum MetadataErrorCode {
GRAPHQL_ERROR
INVALID
NOT_FOUND
REQUIRED
NOT_UPDATED
}

type TranslatableItemConnection {
"""
Pagination data for this connection.
"""
pageInfo: PageInfo!
edges: [TranslatableItemEdge!]!

"""
A total count of items in the collection.
"""
totalCount: Int
}

type TranslatableItemEdge {
"""
The item at the end of the edge.
"""
node: TranslatableItem!

"""
A cursor for use in pagination.
"""
cursor: String!
}

union TranslatableItem =
ProductTranslatableContent
| CollectionTranslatableContent

type ProductTranslatableContent @doc(category: "Products") {
"""
The ID of the product translatable content.
"""
id: ID!

"""
The ID of the product to translate.

Added in Saleor 3.14.
"""
productId: ID!

"""
SEO title to translate.
"""
seoTitle: String

"""
SEO description to translate.
"""
seoDescription: String

"""
Product's name to translate.
"""
name: String!

"""
Product's description to translate.

Rich text format. For reference see https://editorjs.io/
"""
description: JSONString
}

type CollectionTranslatableContent @doc(category: "Products") {
"""
The ID of the collection translatable content.
"""
id: ID!

"""
The ID of the collection to translate.

Added in Saleor 3.14.
"""
collectionId: ID!

"""
SEO title to translate.
"""
seoTitle: String

"""
SEO description to translate.
"""
seoDescription: String

"""
Collection's name to translate.
"""
name: String!

"""
Collection's description to translate.

Rich text format. For reference see https://editorjs.io/
"""
description: JSONString
}

scalar JSONString

type Store {
books(inStock: Boolean, category: String): BookShelf!
}

type BookShelf {
hasBooks: Boolean!
}
36 changes: 36 additions & 0 deletions tests/main/test_main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
import os
from importlib.metadata import version
from pathlib import Path
Expand Down Expand Up @@ -256,6 +257,41 @@ def test_main_generates_correct_package(
assert_the_same_files_in_directories(package_path, expected_package_path)


@pytest.mark.parametrize(
"project_dir, package_name",
[
(
(
CLIENTS_PATH
/ "client_forward_refs_custom_operations"
/ "pyproject.toml",
(
CLIENTS_PATH
/ "client_forward_refs_custom_operations"
/ "schema.graphql",
),
),
"example_client",
),
],
indirect=["project_dir"],
)
def test_main_client_forward_refs_with_custom_operations(project_dir, package_name):
"""ClientForwardRefsPlugin + enable_custom_operations should produce valid client.

Custom operation methods (execute_custom_operation, query, mutation) return
self.* or await self.* - the plugin must not treat 'self' as a generated class.
"""
result = CliRunner().invoke(main)

assert result.exit_code == 0, result.output
package_path = project_dir / package_name
assert package_path.is_dir()
client_py = package_path / "client.py"
assert client_py.exists(), f"Expected {client_py}"
ast.parse(client_py.read_text())


@pytest.mark.parametrize(
"project_dir, expected_exception",
[
Expand Down