diff --git a/test/nn/test_sequential.py b/test/nn/test_sequential.py index 6239aa29fc4d..ba2bbd3b4f41 100644 --- a/test/nn/test_sequential.py +++ b/test/nn/test_sequential.py @@ -174,3 +174,31 @@ def test_sequential_to_hetero(): assert isinstance(out_dict, dict) and len(out_dict) == 2 assert out_dict['paper'].size() == (100, 64) assert out_dict['author'].size() == (100, 64) + + +def test_sequential_no_double_execution(): + # Test for issue #10393: Sequential should not cause double execution + # when initialized from a script with a valid Python identifier name + execution_count = [0] + + def increment_counter(x): + execution_count[0] += 1 + return x + + x = torch.randn(4, 16) + + # Create Sequential model - should not execute forward pass + model = Sequential('x', [(increment_counter, 'x -> y')]) + + # Counter should not be incremented during initialization + assert execution_count[0] == 0 + + # Execute forward pass + output = model(x) + assert execution_count[0] == 1 + assert output.shape == x.shape + + # Execute again to verify it works correctly + output = model(x) + assert execution_count[0] == 2 + assert output.shape == x.shape diff --git a/torch_geometric/nn/sequential.py b/torch_geometric/nn/sequential.py index 1cafc0b06dc1..c0334bf0e0ee 100644 --- a/torch_geometric/nn/sequential.py +++ b/torch_geometric/nn/sequential.py @@ -246,12 +246,21 @@ def _set_jittable_template(self, raise_on_error: bool = False) -> None: root_dir = osp.dirname(osp.realpath(__file__)) uid = '%06x' % random.randrange(16**6) jinja_prefix = f'{self.__module__}_{self.__class__.__name__}_{uid}' + + # Filter out modules that would cause re-execution: + # - '__main__' is the script being executed + # - Modules not in sys.modules would be imported first time + modules_to_import = [] + if (self._caller_module != '__main__' + and self._caller_module in sys.modules): + modules_to_import.append(self._caller_module) + module = module_from_template( module_name=jinja_prefix, template_path=osp.join(root_dir, 'sequential.jinja'), tmp_dirname='sequential', # Keyword arguments: - modules=[self._caller_module], + modules=modules_to_import, signature=self.signature, children=self._children, )