diff --git a/pkg/module_manager/models/modules/basic.go b/pkg/module_manager/models/modules/basic.go index dc829fe6..a78c2fea 100644 --- a/pkg/module_manager/models/modules/basic.go +++ b/pkg/module_manager/models/modules/basic.go @@ -238,6 +238,7 @@ func (bm *BasicModule) ResetState() { bm.l.Unlock() } +// RegisterHooks searches and registers all module hooks from a filesystem or GoHook Registry // RegisterHooks searches and registers all module hooks from a filesystem or GoHook Registry func (bm *BasicModule) RegisterHooks(logger *log.Logger) ([]*hooks.ModuleHook, error) { if bm.hooks.registered { @@ -245,14 +246,14 @@ func (bm *BasicModule) RegisterHooks(logger *log.Logger) ([]*hooks.ModuleHook, e return nil, nil } - hks, err := bm.searchModuleHooks() + searchModuleHooksResult, err := bm.searchModuleHooks() if err != nil { return nil, fmt.Errorf("search module hooks failed: %w", err) } - logger.Debug("Found hooks", slog.Int("count", len(hks))) + logger.Debug("Found hooks", slog.Int("count", len(searchModuleHooksResult.Hooks))) if logger.GetLevel() == log.LevelDebug { - for _, h := range hks { + for _, h := range searchModuleHooksResult.Hooks { logger.Debug("ModuleHook", slog.String("name", h.GetName()), slog.String("path", h.GetPath())) @@ -261,16 +262,22 @@ func (bm *BasicModule) RegisterHooks(logger *log.Logger) ([]*hooks.ModuleHook, e logger.Debug("Register hooks") - if err := bm.registerHooks(hks, logger); err != nil { + if err := bm.registerHooks(searchModuleHooksResult.Hooks, logger); err != nil { return nil, fmt.Errorf("register hooks: %w", err) } bm.hooks.registered = true + bm.hasReadiness = searchModuleHooksResult.HasReadiness - return hks, nil + return searchModuleHooksResult.Hooks, nil +} + +type searchModuleHooksResult struct { + Hooks []*hooks.ModuleHook + HasReadiness bool } -func (bm *BasicModule) searchModuleHooks() ([]*hooks.ModuleHook, error) { +func (bm *BasicModule) searchModuleHooks() (*searchModuleHooksResult, error) { shellHooks, err := bm.searchModuleShellHooks() if err != nil { return nil, fmt.Errorf("search module shell hooks: %w", err) @@ -278,39 +285,42 @@ func (bm *BasicModule) searchModuleHooks() ([]*hooks.ModuleHook, error) { goHooks := bm.searchModuleGoHooks() - batchHooks, err := bm.searchModuleBatchHooks() + batchHooksResult, err := bm.searchModuleBatchHooks() if err != nil { return nil, fmt.Errorf("search module batch hooks: %w", err) } - if len(shellHooks)+len(batchHooks) > 0 { + if len(shellHooks)+len(batchHooksResult.Hooks) > 0 { if err := bm.AssembleEnvironmentForModule(environmentmanager.ShellHookEnvironment); err != nil { return nil, fmt.Errorf("Assemble %q module's environment: %w", bm.GetName(), err) } } - mHooks := make([]*hooks.ModuleHook, 0, len(shellHooks)+len(goHooks)) + result := &searchModuleHooksResult{ + Hooks: make([]*hooks.ModuleHook, 0, len(shellHooks)+len(goHooks)+len(batchHooksResult.Hooks)), + HasReadiness: batchHooksResult.HasReadiness, + } for _, sh := range shellHooks { mh := hooks.NewModuleHook(sh) - mHooks = append(mHooks, mh) + result.Hooks = append(result.Hooks, mh) } for _, gh := range goHooks { mh := hooks.NewModuleHook(gh) - mHooks = append(mHooks, mh) + result.Hooks = append(result.Hooks, mh) } - for _, bh := range batchHooks { + for _, bh := range batchHooksResult.Hooks { mh := hooks.NewModuleHook(bh) - mHooks = append(mHooks, mh) + result.Hooks = append(result.Hooks, mh) } - sort.SliceStable(mHooks, func(i, j int) bool { - return mHooks[i].GetPath() < mHooks[j].GetPath() + sort.SliceStable(result.Hooks, func(i, j int) bool { + return result.Hooks[i].GetPath() < result.Hooks[j].GetPath() }) - return mHooks, nil + return result, nil } func (bm *BasicModule) searchModuleShellHooks() ([]*kind.ShellHook, error) { @@ -372,10 +382,21 @@ func (bm *BasicModule) searchModuleShellHooks() ([]*kind.ShellHook, error) { return hks, nil } -func (bm *BasicModule) searchModuleBatchHooks() ([]*kind.BatchHook, error) { +// searchModuleBatchHooks searches for batch hooks and returns them along with hasReadiness flag. +// This function has no side effects - it doesn't modify bm.hasReadiness directly. +// The caller (RegisterHooks) is responsible for setting bm.hasReadiness after successful registration. +type searchModuleBatchHooksResult struct { + Hooks []*kind.BatchHook + HasReadiness bool +} + +// searchModuleBatchHooks searches for batch hooks and returns them along with hasReadiness flag. +// This function has no side effects - it doesn't modify bm.hasReadiness directly. +// The caller (RegisterHooks) is responsible for setting bm.hasReadiness after successful registration. +func (bm *BasicModule) searchModuleBatchHooks() (*searchModuleBatchHooksResult, error) { hooksDir := filepath.Join(bm.Path, "hooks") if _, err := os.Stat(hooksDir); os.IsNotExist(err) { - return nil, nil + return &searchModuleBatchHooksResult{}, nil } hooksRelativePaths, err := RecursiveGetBatchHookExecutablePaths(bm.safeName(), hooksDir, bm.logger, hooksExcludedDir...) @@ -383,7 +404,9 @@ func (bm *BasicModule) searchModuleBatchHooks() ([]*kind.BatchHook, error) { return nil, err } - hks := make([]*kind.BatchHook, 0) + result := &searchModuleBatchHooksResult{ + Hooks: make([]*kind.BatchHook, 0, len(hooksRelativePaths)), + } // sort hooks by path sort.Strings(hooksRelativePaths) @@ -401,27 +424,27 @@ func (bm *BasicModule) searchModuleBatchHooks() ([]*kind.BatchHook, error) { } if sdkcfgs.Readiness != nil { - if bm.hasReadiness { + if result.HasReadiness { return nil, fmt.Errorf("multiple readiness hooks found in %s", hookPath) } - bm.hasReadiness = true + result.HasReadiness = true // add readiness hook nestedHookName := fmt.Sprintf("%s-readiness", hookName) shHook := kind.NewBatchHook(nestedHookName, hookPath, bm.safeName(), kind.BatchHookReadyKey, bm.keepTemporaryHookFiles, shapp.LogProxyHookJSON, bm.logger.Named("batch-hook")) - hks = append(hks, shHook) + result.Hooks = append(result.Hooks, shHook) } for key, cfg := range sdkcfgs.Hooks { nestedHookName := fmt.Sprintf("%s:%s:%s", hookName, cfg.Metadata.Name, key) shHook := kind.NewBatchHook(nestedHookName, hookPath, bm.safeName(), key, bm.keepTemporaryHookFiles, shapp.LogProxyHookJSON, bm.logger.Named("batch-hook")) - hks = append(hks, shHook) + result.Hooks = append(result.Hooks, shHook) } } - return hks, nil + return result, nil } func RecursiveGetBatchHookExecutablePaths(moduleName, dir string, logger *log.Logger, excludedDirs ...string) ([]string, error) { diff --git a/pkg/module_manager/models/modules/basic_test.go b/pkg/module_manager/models/modules/basic_test.go index 9a282f70..965d3927 100644 --- a/pkg/module_manager/models/modules/basic_test.go +++ b/pkg/module_manager/models/modules/basic_test.go @@ -309,3 +309,27 @@ exit 0 require.NotNil(t, erGetter) require.Equal(t, "Kubernetes version is too low", *erGetter) } + +// TestHasReadinessNotSetOnFailedRegistration tests that hasReadiness is only set +// after successful hook registration. This prevents stale hasReadiness state when +// registration is retried after an error (e.g., in AssembleEnvironmentForModule). +func TestHasReadinessNotSetOnFailedRegistration(t *testing.T) { + tmpModuleDir := t.TempDir() + + bm, err := NewBasicModule("test-readiness-registration", tmpModuleDir, 1, utils.Values{}, nil, nil) + require.NoError(t, err) + + logger := log.NewLogger() + bm.WithLogger(logger) + bm.WithDependencies(stubDeps(logger)) + + // Simulate stale state from a previous failed registration attempt + bm.hasReadiness = true + + // RegisterHooks should overwrite hasReadiness based on actual search results + _, err = bm.RegisterHooks(logger) + require.NoError(t, err) + + // hasReadiness should be false (no hooks in empty module) + require.False(t, bm.hasReadiness, "hasReadiness should reflect actual hooks found") +}