|  | 
|  | 1 | +package threads | 
|  | 2 | + | 
|  | 3 | +import ( | 
|  | 4 | +	"context" | 
|  | 5 | +	"os" | 
|  | 6 | +	"os/signal" | 
|  | 7 | +	"sync" | 
|  | 8 | +	"sync/atomic" | 
|  | 9 | +	"syscall" | 
|  | 10 | + | 
|  | 11 | +	"github.com/openmcp-project/controller-utils/pkg/logging" | 
|  | 12 | +) | 
|  | 13 | + | 
// sigs receives the SIGINT and SIGTERM signals delivered to the process.
// It is a single package-level channel, so it is shared by every ThreadManager.
var sigs chan os.Signal

// init creates the signal channel and subscribes it to SIGINT and SIGTERM
// as soon as the package is loaded.
func init() {
	ch := make(chan os.Signal, 1)
	signal.Notify(ch, syscall.SIGINT, syscall.SIGTERM)
	sigs = ch
}
|  | 20 | + | 
|  | 21 | +// WorkFunc is the function that holds the actual workload of a thread. | 
|  | 22 | +// The ThreadManager cancels the provided context when being stopped, so the workload should listen to the context's Done channel. | 
|  | 23 | +type WorkFunc func(context.Context) error | 
|  | 24 | + | 
|  | 25 | +// OnFinishFunc can be used to react to a thread finishing. | 
|  | 26 | +// Note that its context might already be cancelled (if the ThreadManager is being stopped). | 
|  | 27 | +type OnFinishFunc func(context.Context, ThreadReturn) | 
|  | 28 | + | 
|  | 29 | +// NewThreadManager creates a new ThreadManager. | 
|  | 30 | +// The mgrCtx is used for two purposes: | 
|  | 31 | +//  1. If the context is cancelled, the ThreadManager is stopped. Alternatively, its Stop() method can be called. | 
|  | 32 | +//  2. If the context contains a logger, it is used for logging. | 
|  | 33 | +// | 
|  | 34 | +// The threadCtx will be passed to the threads and is cancelled when the ThreadManager is stopped. | 
|  | 35 | +// If onFinish is not nil, it will be called whenever a thread finishes. It is called after the thread's own onFinish function, if any. | 
|  | 36 | +func NewThreadManager(mgrCtx, threadCtx context.Context, onFinish OnFinishFunc) *ThreadManager { | 
|  | 37 | +	threadCtx, stopFunc := context.WithCancel(threadCtx) | 
|  | 38 | +	return &ThreadManager{ | 
|  | 39 | +		threadCtx:   threadCtx, | 
|  | 40 | +		returns:     make(chan ThreadReturn, 100), | 
|  | 41 | +		onFinish:    onFinish, | 
|  | 42 | +		log:         logging.FromContextOrDiscard(mgrCtx), | 
|  | 43 | +		runOnStart:  []*Thread{}, | 
|  | 44 | +		ctxStop:     mgrCtx.Done(), | 
|  | 45 | +		stopThreads: stopFunc, | 
|  | 46 | +	} | 
|  | 47 | +} | 
|  | 48 | + | 
|  | 49 | +type ThreadManager struct { | 
|  | 50 | +	lock           sync.Mutex | 
|  | 51 | +	threadCtx      context.Context    // context that is passed to threads, is cancelled when the ThreadManager is stopped | 
|  | 52 | +	returns        chan ThreadReturn  // channel to receive thread returns | 
|  | 53 | +	onFinish       OnFinishFunc       // function to call when a thread finishes | 
|  | 54 | +	log            logging.Logger     // logger for the ThreadManager | 
|  | 55 | +	runOnStart     []*Thread          // is filled if threads are added before the ThreadManager is started | 
|  | 56 | +	ctxStop        <-chan struct{}    // channel to stop the ThreadManager | 
|  | 57 | +	stopThreads    context.CancelFunc // function to cancel the context that is passed to threads | 
|  | 58 | +	stopSelf       func()             // convenience function to use the internalStop channel | 
|  | 59 | +	internalStop   chan struct{}      // if the Stop() method is called, we need to stop the internal main loop by using this channel | 
|  | 60 | +	stopped        atomic.Bool        // indicates if the ThreadManager is stopped | 
|  | 61 | +	waitForThreads sync.WaitGroup     // used to wait for threads to finish when stopping the ThreadManager | 
|  | 62 | +} | 
|  | 63 | + | 
|  | 64 | +// Start starts the ThreadManager. | 
|  | 65 | +// This starts a goroutine that listens for thread returns and os signals. | 
|  | 66 | +// Calling Start() multiple times is a no-op, unless the ThreadManager has already been stopped, then it panics. | 
|  | 67 | +// It is possible to add threads before the ThreadManager is started, but they will only be run after Start() is called. | 
|  | 68 | +// Threads added after Start() will be run immediately. | 
|  | 69 | +// There are three ways to stop the ThreadManager again: | 
|  | 70 | +//  1. Cancel the context passed to the ThreadManager during creation. | 
|  | 71 | +//  2. Call the ThreadManager's Stop() method. | 
|  | 72 | +//  3. Send a SIGINT or SIGTERM signal to the process. | 
|  | 73 | +func (tm *ThreadManager) Start() { | 
|  | 74 | +	tm.lock.Lock() | 
|  | 75 | +	defer tm.lock.Unlock() | 
|  | 76 | +	if tm.stopped.Load() { | 
|  | 77 | +		panic("Start called on a stopped ThreadManager") | 
|  | 78 | +	} | 
|  | 79 | +	if tm.isStarted() { | 
|  | 80 | +		tm.log.Debug("Start called, but ThreadManager is already started, nothing to do") | 
|  | 81 | +		return | 
|  | 82 | +	} | 
|  | 83 | +	tm.log.Info("Starting ThreadManager") | 
|  | 84 | +	go func() { | 
|  | 85 | +		for { | 
|  | 86 | +			select { | 
|  | 87 | +			case tr, ok := <-tm.returns: | 
|  | 88 | +				if !ok { | 
|  | 89 | +					// channel has been closed, this means the Stop() method has been called | 
|  | 90 | +					return | 
|  | 91 | +				} | 
|  | 92 | +				if tr.Err != nil { | 
|  | 93 | +					tm.log.Error(tr.Err, "Error in thread", "thread", tr.Thread.id) | 
|  | 94 | +				} | 
|  | 95 | +			case sig := <-sigs: | 
|  | 96 | +				tm.log.Info("Received os signal, stopping ThreadManager", "signal", sig) | 
|  | 97 | +				tm.stop() | 
|  | 98 | +				return | 
|  | 99 | +			case <-tm.ctxStop: | 
|  | 100 | +				tm.stop() | 
|  | 101 | +				return | 
|  | 102 | +			} | 
|  | 103 | +		} | 
|  | 104 | +	}() | 
|  | 105 | +	runOnStart := tm.runOnStart | 
|  | 106 | +	tm.runOnStart = nil | 
|  | 107 | +	if len(runOnStart) > 0 { | 
|  | 108 | +		tm.log.Info("Running threads added before ThreadManager was started", "threadCount", len(runOnStart)) | 
|  | 109 | +		for _, t := range runOnStart { | 
|  | 110 | +			tm.run(t) | 
|  | 111 | +		} | 
|  | 112 | +	} | 
|  | 113 | +} | 
|  | 114 | + | 
|  | 115 | +// Stop stops the ThreadManager. | 
|  | 116 | +// Calling Stop() multiple times is a no-op. | 
|  | 117 | +// It is not possible to start the ThreadManager again after it has been stopped, a new instance must be created. | 
|  | 118 | +// Adding threads after the ThreadManager has been stopped is a no-op. | 
|  | 119 | +// The ThreadManager is also stopped when the context passed to the ThreadManager during creation is cancelled or when a SIGINT or SIGTERM signal is received. | 
|  | 120 | +func (tm *ThreadManager) Stop() { | 
|  | 121 | +	tm.lock.Lock() | 
|  | 122 | +	defer tm.lock.Unlock() | 
|  | 123 | +	if !tm.isStarted() { | 
|  | 124 | +		panic("Stop called on a ThreadManager that has not been started yet") | 
|  | 125 | +	} | 
|  | 126 | +	tm.stop() | 
|  | 127 | +} | 
|  | 128 | + | 
|  | 129 | +func (tm *ThreadManager) stop() { | 
|  | 130 | +	if tm.stopped.Load() { | 
|  | 131 | +		tm.log.Debug("Stop called, but ThreadManager is already stopped, nothing to do") | 
|  | 132 | +		return | 
|  | 133 | +	} | 
|  | 134 | +	tm.log.Info("Stopping ThreadManager, waiting for remaining threads to finish") | 
|  | 135 | +	tm.stopped.Store(true) | 
|  | 136 | +	tm.stopThreads() | 
|  | 137 | + | 
|  | 138 | +	tm.waitForThreads.Wait() | 
|  | 139 | +	close(tm.returns) | 
|  | 140 | +	tm.log.Info("ThreadManager stopped") | 
|  | 141 | +	return | 
|  | 142 | +} | 
|  | 143 | + | 
|  | 144 | +// Run gives a new thread to run to the ThreadManager. | 
|  | 145 | +// id is only used for logging and debugging purposes. | 
|  | 146 | +// work is the actual workload of the thread. | 
|  | 147 | +// onFinish can be used to react to the thread having finished. | 
|  | 148 | +// Note that there are some pre-defined functions that can be used as onFinish functions, e.g. the ThreadManager's Restart method. | 
|  | 149 | +func (tm *ThreadManager) Run(id string, work func(context.Context) error, onFinish OnFinishFunc) { | 
|  | 150 | +	tm.RunThread(NewThread(id, work, onFinish)) | 
|  | 151 | +} | 
|  | 152 | + | 
|  | 153 | +// RunThread is the same as Run, but takes a Thread struct instead of the individual parameters. | 
|  | 154 | +func (tm *ThreadManager) RunThread(t Thread) { | 
|  | 155 | +	tm.lock.Lock() | 
|  | 156 | +	defer tm.lock.Unlock() | 
|  | 157 | +	tm.run(&t) | 
|  | 158 | +} | 
|  | 159 | + | 
|  | 160 | +func (tm *ThreadManager) run(t *Thread) { | 
|  | 161 | +	if t == nil { | 
|  | 162 | +		tm.log.Error(nil, "run(t *Thread) called with nil Thread, this should never happen") | 
|  | 163 | +		return | 
|  | 164 | +	} | 
|  | 165 | +	if tm.stopped.Load() { | 
|  | 166 | +		tm.log.Info("Skipping thread run because ThreadManager is already stopped", "thread", t.id) | 
|  | 167 | +		return | 
|  | 168 | +	} | 
|  | 169 | +	if !tm.isStarted() { | 
|  | 170 | +		tm.runOnStart = append(tm.runOnStart, t) | 
|  | 171 | +		tm.log.Debug("ThreadManager has not been started yet, enqueuing thread to run on start", "thread", t.ID()) | 
|  | 172 | +		return | 
|  | 173 | +	} | 
|  | 174 | +	tm.log.Debug("Running thread", "thread", t.id) | 
|  | 175 | +	tm.waitForThreads.Add(1) | 
|  | 176 | +	go func() { | 
|  | 177 | +		defer tm.waitForThreads.Done() | 
|  | 178 | +		var err error | 
|  | 179 | +		if t.work != nil { | 
|  | 180 | +			err = t.work(tm.threadCtx) | 
|  | 181 | +		} else { | 
|  | 182 | +			tm.log.Debug("Thread has no work function", "thread", t.id) | 
|  | 183 | +		} | 
|  | 184 | +		tr := NewThreadReturn(t, err) | 
|  | 185 | +		if t.onFinish != nil { | 
|  | 186 | +			tm.log.Debug("Calling the thread's onFinish function", "thread", t.id) | 
|  | 187 | +			t.onFinish(tm.threadCtx, tr) | 
|  | 188 | +		} | 
|  | 189 | +		if tm.onFinish != nil { | 
|  | 190 | +			tm.log.Debug("Calling the thread manager's onFinish function", "thread", tr.Thread.id) | 
|  | 191 | +			tm.onFinish(tm.threadCtx, tr) | 
|  | 192 | +		} | 
|  | 193 | +		tm.returns <- tr | 
|  | 194 | +		tm.log.Debug("Thread finished", "thread", t.id) | 
|  | 195 | +	}() | 
|  | 196 | +} | 
|  | 197 | + | 
|  | 198 | +func (tm *ThreadManager) isStarted() bool { | 
|  | 199 | +	return tm.runOnStart == nil | 
|  | 200 | +} | 
|  | 201 | + | 
|  | 202 | +// IsStarted returns true if the ThreadManager has been started. | 
|  | 203 | +// Note that this will return true if the ThreadManager has been started at some point, even if it has been stopped by now. | 
|  | 204 | +func (tm *ThreadManager) IsStarted() bool { | 
|  | 205 | +	tm.lock.Lock() | 
|  | 206 | +	defer tm.lock.Unlock() | 
|  | 207 | +	return tm.isStarted() | 
|  | 208 | +} | 
|  | 209 | + | 
|  | 210 | +// IsStopped returns true if the ThreadManager has been stopped. | 
|  | 211 | +// Note that this will return false if the ThreadManager has not been started yet. | 
|  | 212 | +func (tm *ThreadManager) IsStopped() bool { | 
|  | 213 | +	return tm.stopped.Load() | 
|  | 214 | +} | 
|  | 215 | + | 
|  | 216 | +// IsRunning returns true if the ThreadManager is currently running, | 
|  | 217 | +// meaning it has been started and not yet been stopped. | 
|  | 218 | +// This is a convenience function that is equivalent to calling IsStarted() && !IsStopped(). | 
|  | 219 | +func (tm *ThreadManager) IsRunning() bool { | 
|  | 220 | +	return tm.IsStarted() && !tm.IsStopped() | 
|  | 221 | +} | 
|  | 222 | + | 
|  | 223 | +var _ OnFinishFunc = (*ThreadManager)(nil).Restart | 
|  | 224 | + | 
|  | 225 | +// Restart is a pre-defined onFinish function that can be used to restart a thread after it has finished. | 
|  | 226 | +// This method is not meant to be called directly, instead pass it to the ThreadManager's Run method as the onFinish parameter: | 
|  | 227 | +// | 
|  | 228 | +//	tm.Run("myThread", myWorkFunc, tm.Restart) | 
|  | 229 | +func (tm *ThreadManager) Restart(_ context.Context, tr ThreadReturn) { | 
|  | 230 | +	if tm.stopped.Load() { | 
|  | 231 | +		return | 
|  | 232 | +	} | 
|  | 233 | +	tm.RunThread(*tr.Thread) | 
|  | 234 | +} | 
|  | 235 | + | 
|  | 236 | +var _ OnFinishFunc = (*ThreadManager)(nil).RestartOnError | 
|  | 237 | + | 
|  | 238 | +// RestartOnError is a pre-defined onFinish function that can be used to restart a thread after it has finished, if it finished with an error. | 
|  | 239 | +// It is the opposite of RestartOnSuccess. | 
|  | 240 | +// This method is not meant to be called directly, instead pass it to the ThreadManager's Run method as the onFinish parameter: | 
|  | 241 | +// | 
|  | 242 | +//	tm.Run("myThread", myWorkFunc, tm.RestartOnError) | 
|  | 243 | +func (tm *ThreadManager) RestartOnError(_ context.Context, tr ThreadReturn) { | 
|  | 244 | +	if tr.Err != nil { | 
|  | 245 | +		tm.Restart(tm.threadCtx, tr) | 
|  | 246 | +	} | 
|  | 247 | +} | 
|  | 248 | + | 
|  | 249 | +var _ OnFinishFunc = (*ThreadManager)(nil).RestartOnSuccess | 
|  | 250 | + | 
|  | 251 | +// RestartOnSuccess is a pre-defined onFinish function that can be used to restart a thread after it has finished, if it didn't throw an error. | 
|  | 252 | +// It is the opposite of RestartOnError. | 
|  | 253 | +// This method is not meant to be called directly, instead pass it to the ThreadManager's Run method as the onFinish parameter: | 
|  | 254 | +// | 
|  | 255 | +//	tm.Run("myThread", myWorkFunc, tm.RestartOnSuccess) | 
|  | 256 | +func (tm *ThreadManager) RestartOnSuccess(_ context.Context, tr ThreadReturn) { | 
|  | 257 | +	if tr.Err == nil { | 
|  | 258 | +		tm.Restart(tm.threadCtx, tr) | 
|  | 259 | +	} | 
|  | 260 | +} | 
|  | 261 | + | 
|  | 262 | +// NewThread creates a new thread with the given id, work function and onFinish function. | 
|  | 263 | +// It is usually not required to call this function directly, instead use the ThreadManager's Run method. | 
|  | 264 | +// The Thread's fields are considered immutable after creation. | 
|  | 265 | +func NewThread(id string, work WorkFunc, onFinish OnFinishFunc) Thread { | 
|  | 266 | +	return Thread{ | 
|  | 267 | +		id:       id, | 
|  | 268 | +		work:     work, | 
|  | 269 | +		onFinish: onFinish, | 
|  | 270 | +	} | 
|  | 271 | +} | 
|  | 272 | + | 
|  | 273 | +// Thread represents a thread that can be run by the ThreadManager. | 
|  | 274 | +type Thread struct { | 
|  | 275 | +	id       string | 
|  | 276 | +	work     WorkFunc | 
|  | 277 | +	onFinish OnFinishFunc | 
|  | 278 | +} | 
|  | 279 | + | 
|  | 280 | +// ID returns the id of the thread. | 
|  | 281 | +func (t *Thread) ID() string { | 
|  | 282 | +	return t.id | 
|  | 283 | +} | 
|  | 284 | + | 
|  | 285 | +// WorkFunc returns the workload function of the thread. | 
|  | 286 | +func (t *Thread) WorkFunc() WorkFunc { | 
|  | 287 | +	return t.work | 
|  | 288 | +} | 
|  | 289 | + | 
|  | 290 | +// OnFinishFunc returns the onFinish function of the thread. | 
|  | 291 | +func (t *Thread) OnFinishFunc() OnFinishFunc { | 
|  | 292 | +	return t.onFinish | 
|  | 293 | +} | 
|  | 294 | + | 
|  | 295 | +// NewThreadReturn constructs a new ThreadReturn object. | 
|  | 296 | +// This is used by the ThreadManager internally and it should rarely be necessary to call this function directly. | 
|  | 297 | +func NewThreadReturn(thread *Thread, err error) ThreadReturn { | 
|  | 298 | +	return ThreadReturn{ | 
|  | 299 | +		Err:    err, | 
|  | 300 | +		Thread: thread, | 
|  | 301 | +	} | 
|  | 302 | +} | 
|  | 303 | + | 
|  | 304 | +// ThreadReturn represents the result of a thread's execution. | 
|  | 305 | +// It contains a reference to the thread and an error, if any occurred. | 
|  | 306 | +type ThreadReturn struct { | 
|  | 307 | +	Err    error | 
|  | 308 | +	Thread *Thread | 
|  | 309 | +} | 
0 commit comments