11from __future__ import annotations
22
3- import inspect
3+ from types import ModuleType
44
55from celery import Celery
6- from celery .app .base import PendingConfiguration
76
87from pytest_celery .api .container import CeleryTestContainer
98from pytest_celery .vendors .worker .defaults import DEFAULT_WORKER_ENV
109from pytest_celery .vendors .worker .defaults import DEFAULT_WORKER_LOG_LEVEL
1110from pytest_celery .vendors .worker .defaults import DEFAULT_WORKER_NAME
1211from pytest_celery .vendors .worker .defaults import DEFAULT_WORKER_QUEUE
1312from pytest_celery .vendors .worker .defaults import DEFAULT_WORKER_VERSION
13+ from pytest_celery .vendors .worker .volume import WorkerInitialContent
1414
1515
1616class CeleryWorkerContainer (CeleryTestContainer ):
@@ -37,9 +37,17 @@ def worker_name(cls) -> str:
3737 def worker_queue (cls ) -> str :
3838 return DEFAULT_WORKER_QUEUE
3939
40+ @classmethod
41+ def app_module (cls ) -> ModuleType :
42+ from pytest_celery .vendors .worker import app
43+
44+ return app
45+
4046 @classmethod
4147 def tasks_modules (cls ) -> set :
42- return set ()
48+ from pytest_celery .vendors .worker import tasks
49+
50+ return {tasks }
4351
4452 @classmethod
4553 def signals_modules (cls ) -> set :
@@ -76,96 +84,24 @@ def env(cls, celery_worker_cluster_config: dict, initial: dict | None = None) ->
7684 @classmethod
7785 def initial_content (
7886 cls ,
79- worker_tasks : set ,
87+ worker_tasks : set | None = None ,
8088 worker_signals : set | None = None ,
8189 worker_app : Celery | None = None ,
90+ app_module : ModuleType | None = None ,
8291 ) -> dict :
83- from pytest_celery .vendors .worker import app as app_module
92+ if app_module is None :
93+ app_module = cls .app_module ()
8494
85- app_module_src = inspect .getsource (app_module )
95+ if worker_tasks is None :
96+ worker_tasks = cls .tasks_modules ()
8697
87- imports = dict ()
88- initial_content = cls . _initial_content_worker_tasks ( worker_tasks )
89- imports [ "tasks_imports" ] = initial_content . pop ( "tasks_imports" )
98+ content = WorkerInitialContent ()
99+ content . set_app_module ( app_module )
100+ content . add_modules ( "tasks" , worker_tasks )
90101 if worker_signals :
91- initial_content .update (cls ._initial_content_worker_signals (worker_signals ))
92- imports ["signals_imports" ] = initial_content .pop ("signals_imports" )
102+ content .add_modules ("signals" , worker_signals )
93103 if worker_app :
94- # Accessing the worker_app.conf.changes.data property will trigger the PendingConfiguration to be resolved
95- # and the changes will be applied to the worker_app.conf, so we make a clone app to avoid affecting the
96- # original app object.
97- app = Celery (worker_app .main )
98- app .conf = worker_app .conf
99- config_changes_from_defaults = app .conf .changes .copy ()
100- if isinstance (config_changes_from_defaults , PendingConfiguration ):
101- config_changes_from_defaults = config_changes_from_defaults .data .changes
102- if not isinstance (config_changes_from_defaults , dict ):
103- raise TypeError (f"Unexpected type for config_changes: { type (config_changes_from_defaults )} " )
104- del config_changes_from_defaults ["deprecated_settings" ]
105-
106- name_code = f'name = "{ worker_app .main } "'
107- else :
108- config_changes_from_defaults = {}
109- name_code = f'name = "{ cls .worker_name ()} "'
110-
111- imports_format = "{%s}" % "}{" .join (imports .keys ())
112- imports_format = imports_format .format (** imports )
113- app_module_src = app_module_src .replace ("{0}" , imports_format )
114-
115- app_module_src = app_module_src .replace ("{1}" , name_code )
116-
117- config_items = (f" { repr (key )} : { repr (value )} " for key , value in config_changes_from_defaults .items ())
118- config_code = (
119- "config_updates = {\n " + ",\n " .join (config_items ) + "\n }"
120- if config_changes_from_defaults
121- else "config_updates = {}"
122- )
123- app_module_src = app_module_src .replace ("{2}" , config_code )
124-
125- initial_content ["app.py" ] = app_module_src .encode ()
126- return initial_content
127-
128- @classmethod
129- def _initial_content_worker_tasks (cls , worker_tasks : set ) -> dict :
130- from pytest_celery .vendors .worker import tasks
131-
132- worker_tasks .add (tasks )
104+ content .set_app_name (worker_app .main )
105+ content .set_config_from_object (worker_app )
133106
134- import_string = ""
135-
136- for module in worker_tasks :
137- import_string += f"from { module .__name__ } import *\n "
138-
139- initial_content = {
140- "__init__.py" : b"" ,
141- "tasks_imports" : import_string ,
142- }
143- if worker_tasks :
144- default_worker_tasks_src = {
145- f"{ module .__name__ .replace ('.' , '/' )} .py" : inspect .getsource (module ).encode () for module in worker_tasks
146- }
147- initial_content .update (default_worker_tasks_src )
148- else :
149- print ("No tasks found" )
150- return initial_content
151-
152- @classmethod
153- def _initial_content_worker_signals (cls , worker_signals : set ) -> dict :
154- import_string = ""
155-
156- for module in worker_signals :
157- import_string += f"from { module .__name__ } import *\n "
158-
159- initial_content = {
160- "__init__.py" : b"" ,
161- "signals_imports" : import_string ,
162- }
163- if worker_signals :
164- default_worker_signals_src = {
165- f"{ module .__name__ .replace ('.' , '/' )} .py" : inspect .getsource (module ).encode ()
166- for module in worker_signals
167- }
168- initial_content .update (default_worker_signals_src )
169- else :
170- print ("No signals found" )
171- return initial_content
107+ return content .generate ()
0 commit comments