@@ -85,15 +85,17 @@ type (
8585 hostEnv * hostEnvImpl
8686 }
8787
88+ activityProvider func (name string ) activity
8889 // activityTaskHandlerImpl is the implementation of ActivityTaskHandler
8990 activityTaskHandlerImpl struct {
90- taskListName string
91- identity string
92- service m.TChanWorkflowService
93- metricsScope tally.Scope
94- logger * zap.Logger
95- userContext context.Context
96- hostEnv * hostEnvImpl
91+ taskListName string
92+ identity string
93+ service m.TChanWorkflowService
94+ metricsScope tally.Scope
95+ logger * zap.Logger
96+ userContext context.Context
97+ hostEnv * hostEnvImpl
98+ activityProvider activityProvider
9799 }
98100
99101 // history wrapper method to help information about events.
@@ -809,15 +811,25 @@ func newActivityTaskHandler(
809811 service m.TChanWorkflowService ,
810812 params workerExecutionParameters ,
811813 env * hostEnvImpl ,
814+ ) ActivityTaskHandler {
815+ return newActivityTaskHandlerWithCustomProvider (service , params , env , nil )
816+ }
817+
818+ func newActivityTaskHandlerWithCustomProvider (
819+ service m.TChanWorkflowService ,
820+ params workerExecutionParameters ,
821+ env * hostEnvImpl ,
822+ activityProvider activityProvider ,
812823) ActivityTaskHandler {
813824 return & activityTaskHandlerImpl {
814- taskListName : params .TaskList ,
815- identity : params .Identity ,
816- service : service ,
817- logger : params .Logger ,
818- metricsScope : params .MetricsScope ,
819- userContext : params .UserContext ,
820- hostEnv : env ,
825+ taskListName : params .TaskList ,
826+ identity : params .Identity ,
827+ service : service ,
828+ logger : params .Logger ,
829+ metricsScope : params .MetricsScope ,
830+ userContext : params .UserContext ,
831+ hostEnv : env ,
832+ activityProvider : activityProvider ,
821833 }
822834}
823835
@@ -958,10 +970,10 @@ func (ath *activityTaskHandlerImpl) Execute(t *s.PollForActivityTaskResponse) (r
958970 defer invoker .Close ()
959971 ctx := WithActivityTask (canCtx , t , invoker , ath .logger , ath .metricsScope )
960972 activityType := * t .GetActivityType ()
961- activityImplementation , ok := ath . hostEnv .getActivity (activityType .GetName ())
962- if ! ok {
973+ activityImplementation := ath .getActivity (activityType .GetName ())
974+ if activityImplementation == nil {
963975 // Couldn't find the activity implementation.
964- return nil , fmt .Errorf ("Unable to find activityType=%v" , activityType .GetName ())
976+ return nil , fmt .Errorf ("unable to find activityType=%v" , activityType .GetName ())
965977 }
966978
967979 // panic handler
@@ -999,6 +1011,18 @@ func (ath *activityTaskHandlerImpl) Execute(t *s.PollForActivityTaskResponse) (r
9991011 return convertActivityResultToRespondRequest (ath .identity , t .TaskToken , output , err ), nil
10001012}
10011013
1014+ func (ath * activityTaskHandlerImpl ) getActivity (name string ) activity {
1015+ if ath .activityProvider != nil {
1016+ return ath .activityProvider (name )
1017+ }
1018+
1019+ if a , ok := ath .hostEnv .getActivity (name ); ok {
1020+ return a
1021+ }
1022+
1023+ return nil
1024+ }
1025+
10021026func createNewDecision (decisionType s.DecisionType ) * s.Decision {
10031027 return & s.Decision {
10041028 DecisionType : common .DecisionTypePtr (decisionType ),
0 commit comments