66package org .opensearch .ml .cluster ;
77
88import static org .opensearch .ml .settings .MLCommonsSettings .ML_COMMONS_EXCLUDE_NODE_NAMES ;
9+ import static org .opensearch .ml .settings .MLCommonsSettings .ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES ;
910import static org .opensearch .ml .settings .MLCommonsSettings .ML_COMMONS_ONLY_RUN_ON_ML_NODE ;
11+ import static org .opensearch .ml .settings .MLCommonsSettings .ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES ;
1012
1113import java .util .ArrayList ;
12- import java .util .Arrays ;
1314import java .util .HashSet ;
1415import java .util .List ;
1516import java .util .Set ;
2122import org .opensearch .common .settings .Settings ;
2223import org .opensearch .core .common .Strings ;
2324import org .opensearch .ml .common .CommonValue ;
25+ import org .opensearch .ml .common .FunctionName ;
2426import org .opensearch .ml .utils .MLNodeUtils ;
2527
2628import lombok .extern .log4j .Log4j2 ;
@@ -31,6 +33,8 @@ public class DiscoveryNodeHelper {
3133 private final HotDataNodePredicate eligibleNodeFilter ;
3234 private volatile Boolean onlyRunOnMLNode ;
3335 private volatile Set <String > excludedNodeNames ;
36+ private volatile Set <String > remoteModelEligibleNodeRoles ;
37+ private volatile Set <String > localModelEligibleNodeRoles ;
3438
3539 public DiscoveryNodeHelper (ClusterService clusterService , Settings settings ) {
3640 this .clusterService = clusterService ;
@@ -41,44 +45,61 @@ public DiscoveryNodeHelper(ClusterService clusterService, Settings settings) {
4145 clusterService
4246 .getClusterSettings ()
4347 .addSettingsUpdateConsumer (ML_COMMONS_EXCLUDE_NODE_NAMES , it -> excludedNodeNames = Strings .commaDelimitedListToSet (it ));
48+ remoteModelEligibleNodeRoles = new HashSet <>();
49+ remoteModelEligibleNodeRoles .addAll (ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES .get (settings ));
50+ clusterService .getClusterSettings ().addSettingsUpdateConsumer (ML_COMMONS_REMOTE_MODEL_ELIGIBLE_NODE_ROLES , it -> {
51+ remoteModelEligibleNodeRoles = new HashSet <>(it );
52+ });
53+ localModelEligibleNodeRoles = new HashSet <>();
54+ localModelEligibleNodeRoles .addAll (ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES .get (settings ));
55+ clusterService .getClusterSettings ().addSettingsUpdateConsumer (ML_COMMONS_LOCAL_MODEL_ELIGIBLE_NODE_ROLES , it -> {
56+ localModelEligibleNodeRoles = new HashSet <>(it );
57+ });
4458 }
4559
46- public String [] getEligibleNodeIds () {
47- DiscoveryNode [] nodes = getEligibleNodes ();
60+ public String [] getEligibleNodeIds (FunctionName functionName ) {
61+ DiscoveryNode [] nodes = getEligibleNodes (functionName );
4862 String [] nodeIds = new String [nodes .length ];
4963 for (int i = 0 ; i < nodes .length ; i ++) {
5064 nodeIds [i ] = nodes [i ].getId ();
5165 }
5266 return nodeIds ;
5367 }
5468
55- public DiscoveryNode [] getEligibleNodes () {
69+ public DiscoveryNode [] getEligibleNodes (FunctionName functionName ) {
5670 ClusterState state = this .clusterService .state ();
57- final List <DiscoveryNode > eligibleMLNodes = new ArrayList <>();
58- final List <DiscoveryNode > eligibleDataNodes = new ArrayList <>();
71+ final List <DiscoveryNode > eligibleNodes = new ArrayList <>();
5972 for (DiscoveryNode node : state .nodes ()) {
6073 if (excludedNodeNames != null && excludedNodeNames .contains (node .getName ())) {
6174 continue ;
6275 }
63- if (MLNodeUtils .isMLNode (node )) {
64- eligibleMLNodes .add (node );
65- }
66- if (!onlyRunOnMLNode && node .isDataNode () && isEligibleDataNode (node )) {
67- eligibleDataNodes .add (node );
76+ if (functionName == FunctionName .REMOTE ) {// remote model
77+ getEligibleNodes (remoteModelEligibleNodeRoles , eligibleNodes , node );
78+ } else { // local model
79+ if (onlyRunOnMLNode ) {
80+ if (MLNodeUtils .isMLNode (node )) {
81+ eligibleNodes .add (node );
82+ }
83+ } else {
84+ getEligibleNodes (localModelEligibleNodeRoles , eligibleNodes , node );
85+ }
6886 }
6987 }
70- if (eligibleMLNodes .size () > 0 ) {
71- DiscoveryNode [] mlNodes = eligibleMLNodes .toArray (new DiscoveryNode [0 ]);
72- log .debug ("Find {} dedicated ML nodes: {}" , eligibleMLNodes .size (), Arrays .toString (mlNodes ));
73- return mlNodes ;
74- } else {
75- DiscoveryNode [] dataNodes = eligibleDataNodes .toArray (new DiscoveryNode [0 ]);
76- log .debug ("Find no dedicated ML nodes. But have {} data nodes: {}" , eligibleDataNodes .size (), Arrays .toString (dataNodes ));
77- return dataNodes ;
88+ return eligibleNodes .toArray (new DiscoveryNode [0 ]);
89+ }
90+
91+ private void getEligibleNodes (Set <String > allowedNodeRoles , List <DiscoveryNode > eligibleNodes , DiscoveryNode node ) {
92+ if (allowedNodeRoles .contains ("data" ) && isEligibleDataNode (node )) {
93+ eligibleNodes .add (node );
94+ }
95+ for (String nodeRole : allowedNodeRoles ) {
96+ if (!"data" .equals (nodeRole ) && node .getRoles ().stream ().anyMatch (r -> r .roleName ().equals (nodeRole ))) {
97+ eligibleNodes .add (node );
98+ }
7899 }
79100 }
80101
81- public String [] filterEligibleNodes (String [] nodeIds ) {
102+ public String [] filterEligibleNodes (FunctionName functionName , String [] nodeIds ) {
82103 if (nodeIds == null || nodeIds .length == 0 ) {
83104 return nodeIds ;
84105 }
@@ -88,14 +109,30 @@ public String[] filterEligibleNodes(String[] nodeIds) {
88109 if (excludedNodeNames != null && excludedNodeNames .contains (node .getName ())) {
89110 continue ;
90111 }
91- if (MLNodeUtils .isMLNode (node )) {
92- eligibleNodes .add (node .getId ());
112+ if (functionName == FunctionName .REMOTE ) {// remote model
113+ getEligibleNodes (remoteModelEligibleNodeRoles , eligibleNodes , node );
114+ } else { // local model
115+ if (onlyRunOnMLNode ) {
116+ if (MLNodeUtils .isMLNode (node )) {
117+ eligibleNodes .add (node .getId ());
118+ }
119+ } else {
120+ getEligibleNodes (localModelEligibleNodeRoles , eligibleNodes , node );
121+ }
93122 }
94- if (!onlyRunOnMLNode && node .isDataNode () && isEligibleDataNode (node )) {
123+ }
124+ return eligibleNodes .toArray (new String [0 ]);
125+ }
126+
127+ private void getEligibleNodes (Set <String > allowedNodeRoles , Set <String > eligibleNodes , DiscoveryNode node ) {
128+ if (allowedNodeRoles .contains ("data" ) && isEligibleDataNode (node )) {
129+ eligibleNodes .add (node .getId ());
130+ }
131+ for (String nodeRole : allowedNodeRoles ) {
132+ if (!"data" .equals (nodeRole ) && node .getRoles ().stream ().anyMatch (r -> r .roleName ().equals (nodeRole ))) {
95133 eligibleNodes .add (node .getId ());
96134 }
97135 }
98- return eligibleNodes .toArray (new String [0 ]);
99136 }
100137
101138 public DiscoveryNode [] getAllNodes () {
0 commit comments