@@ -64,8 +64,14 @@ def gen_generator_func(
6464 setup_generator_class (builder )
6565 load_env_registers (builder )
6666 gen_arg_defaults (builder )
67- finalize_env_class (builder )
68- builder .add (Return (instantiate_generator_class (builder )))
67+ if builder .fn_info .can_merge_generator_and_env_classes ():
68+ gen = instantiate_generator_class (builder )
69+ builder .fn_info ._curr_env_reg = gen
70+ finalize_env_class (builder )
71+ else :
72+ finalize_env_class (builder )
73+ gen = instantiate_generator_class (builder )
74+ builder .add (Return (gen ))
6975
7076 args , _ , blocks , ret_type , fn_info = builder .leave ()
7177 func_ir , func_reg = gen_func_ir (args , blocks , fn_info )
@@ -122,30 +128,38 @@ def instantiate_generator_class(builder: IRBuilder) -> Value:
122128 fitem = builder .fn_info .fitem
123129 generator_reg = builder .add (Call (builder .fn_info .generator_class .ir .ctor , [], fitem .line ))
124130
125- # Get the current environment register. If the current function is nested, then the
126- # generator class gets instantiated from the callable class' '__call__' method, and hence
127- # we use the callable class' environment register. Otherwise, we use the original
128- # function's environment register.
129- if builder .fn_info .is_nested :
130- curr_env_reg = builder .fn_info .callable_class .curr_env_reg
131+ if builder .fn_info .can_merge_generator_and_env_classes ():
132+ # Set the generator instance to the initial state (zero).
133+ zero = Integer (0 )
134+ builder .add (SetAttr (generator_reg , NEXT_LABEL_ATTR_NAME , zero , fitem .line ))
131135 else :
132- curr_env_reg = builder .fn_info .curr_env_reg
133-
134- # Set the generator class' environment attribute to point at the environment class
135- # defined in the current scope.
136- builder .add (SetAttr (generator_reg , ENV_ATTR_NAME , curr_env_reg , fitem .line ))
137-
138- # Set the generator class' environment class' NEXT_LABEL_ATTR_NAME attribute to 0.
139- zero = Integer (0 )
140- builder .add (SetAttr (curr_env_reg , NEXT_LABEL_ATTR_NAME , zero , fitem .line ))
136+ # Get the current environment register. If the current function is nested, then the
137+ # generator class gets instantiated from the callable class' '__call__' method, and hence
138+ # we use the callable class' environment register. Otherwise, we use the original
139+ # function's environment register.
140+ if builder .fn_info .is_nested :
141+ curr_env_reg = builder .fn_info .callable_class .curr_env_reg
142+ else :
143+ curr_env_reg = builder .fn_info .curr_env_reg
144+
145+ # Set the generator class' environment attribute to point at the environment class
146+ # defined in the current scope.
147+ builder .add (SetAttr (generator_reg , ENV_ATTR_NAME , curr_env_reg , fitem .line ))
148+
149+ # Set the generator instance's environment to the initial state (zero).
150+ zero = Integer (0 )
151+ builder .add (SetAttr (curr_env_reg , NEXT_LABEL_ATTR_NAME , zero , fitem .line ))
141152 return generator_reg
142153
143154
144155def setup_generator_class (builder : IRBuilder ) -> ClassIR :
145156 name = f"{ builder .fn_info .namespaced_name ()} _gen"
146157
147158 generator_class_ir = ClassIR (name , builder .module_name , is_generated = True )
148- generator_class_ir .attributes [ENV_ATTR_NAME ] = RInstance (builder .fn_info .env_class )
159+ if builder .fn_info .can_merge_generator_and_env_classes ():
160+ builder .fn_info .env_class = generator_class_ir
161+ else :
162+ generator_class_ir .attributes [ENV_ATTR_NAME ] = RInstance (builder .fn_info .env_class )
149163 generator_class_ir .mro = [generator_class_ir ]
150164
151165 builder .classes .append (generator_class_ir )
@@ -392,7 +406,10 @@ def setup_env_for_generator_class(builder: IRBuilder) -> None:
392406 cls .send_arg_reg = exc_arg
393407
394408 cls .self_reg = builder .read (self_target , fitem .line )
395- cls .curr_env_reg = load_outer_env (builder , cls .self_reg , builder .symtables [- 1 ])
409+ if builder .fn_info .can_merge_generator_and_env_classes ():
410+ cls .curr_env_reg = cls .self_reg
411+ else :
412+ cls .curr_env_reg = load_outer_env (builder , cls .self_reg , builder .symtables [- 1 ])
396413
397414 # Define a variable representing the label to go to the next time
398415 # the '__next__' function of the generator is called, and add it
0 commit comments