1515
1616import java .io .IOException ;
1717import java .util .Arrays ;
18+ import java .util .HashMap ;
1819import java .util .List ;
20+ import java .util .Map ;
1921import java .util .Optional ;
2022import java .util .stream .Collectors ;
2123
2830import org .opensearch .common .xcontent .XContentBuilder ;
2931import org .opensearch .common .xcontent .XContentParser ;
3032import org .opensearch .ml .action .profile .MLProfileAction ;
33+ import org .opensearch .ml .action .profile .MLProfileModelResponse ;
3134import org .opensearch .ml .action .profile .MLProfileNodeResponse ;
3235import org .opensearch .ml .action .profile .MLProfileRequest ;
36+ import org .opensearch .ml .common .MLTask ;
37+ import org .opensearch .ml .profile .MLModelProfile ;
3338import org .opensearch .ml .profile .MLProfileInput ;
39+ import org .opensearch .ml .utils .RestActionUtils ;
3440import org .opensearch .rest .BaseRestHandler ;
3541import org .opensearch .rest .BytesRestResponse ;
3642import org .opensearch .rest .RestRequest ;
3743import org .opensearch .rest .RestStatus ;
3844
3945import com .google .common .collect .ImmutableList ;
46+ import com .google .common .collect .ImmutableMap ;
4047
4148@ Log4j2
4249public class RestMLProfileAction extends BaseRestHandler {
4350 private static final String PROFILE_ML_ACTION = "profile_ml" ;
4451
52+ private static final String VIEW = "view" ;
53+ private static final String MODEL_VIEW = "model" ;
54+ private static final String NODE_VIEW = "node" ;
55+
4556 private ClusterService clusterService ;
4657
4758 /**
@@ -80,6 +91,7 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
8091 } else {
8192 mlProfileInput = createMLProfileInputFromRequestParams (request );
8293 }
94+ String view = RestActionUtils .getStringParam (request , VIEW ).orElse (NODE_VIEW );
8395 String [] nodeIds = mlProfileInput .retrieveProfileOnAllNodes ()
8496 ? getAllNodes (clusterService )
8597 : mlProfileInput .getNodeIds ().toArray (new String [0 ]);
@@ -93,7 +105,16 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
93105 List <MLProfileNodeResponse > nodeProfiles = r .getNodes ().stream ().filter (s -> !s .isEmpty ()).collect (Collectors .toList ());
94106 log .debug ("Build MLProfileNodeResponse for size of {}" , nodeProfiles .size ());
95107 if (nodeProfiles .size () > 0 ) {
96- r .toXContent (builder , ToXContent .EMPTY_PARAMS );
108+ if (NODE_VIEW .equals (view )) {
109+ r .toXContent (builder , ToXContent .EMPTY_PARAMS );
110+ } else if (MODEL_VIEW .equals (view )) {
111+ Map <String , MLProfileModelResponse > modelCentricProfileMap = buildModelCentricResult (nodeProfiles );
112+ builder .startObject ("models" );
113+ for (Map .Entry <String , MLProfileModelResponse > entry : modelCentricProfileMap .entrySet ()) {
114+ builder .field (entry .getKey (), entry .getValue ());
115+ }
116+ builder .endObject ();
117+ }
97118 }
98119 builder .endObject ();
99120 channel .sendResponse (new BytesRestResponse (RestStatus .OK , builder ));
@@ -105,6 +126,59 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
105126 };
106127 }
107128
129+ /**
130+ * The data structure for node centric is:
131+ * MLProfileNodeResponse:
132+ * taskMap: Map<String, MLTask>
133+ * modelMap: Map<String, MLModelProfile> model_id, MLModelProfile
134+ * And we need to convert to format like this:
135+ * modelMap: Map<String, Map<String, MLModelProfile>>
136+ */
137+ private Map <String , MLProfileModelResponse > buildModelCentricResult (List <MLProfileNodeResponse > nodeResponses ) {
138+ // aggregate model information into one final map.
139+ Map <String , MLProfileModelResponse > modelCentricMap = new HashMap <>();
140+ for (MLProfileNodeResponse mlProfileNodeResponse : nodeResponses ) {
141+ String nodeId = mlProfileNodeResponse .getNode ().getId ();
142+ Map <String , MLModelProfile > modelProfileMap = mlProfileNodeResponse .getMlNodeModels ();
143+ Map <String , MLTask > taskProfileMap = mlProfileNodeResponse .getMlNodeTasks ();
144+ for (Map .Entry <String , MLModelProfile > entry : modelProfileMap .entrySet ()) {
145+ MLProfileModelResponse mlProfileModelResponse = modelCentricMap .get (entry .getKey ());
146+ if (mlProfileModelResponse == null ) {
147+ mlProfileModelResponse = new MLProfileModelResponse (
148+ entry .getValue ().getTargetWorkerNodes (),
149+ entry .getValue ().getWorkerNodes ()
150+ );
151+ modelCentricMap .put (entry .getKey (), mlProfileModelResponse );
152+ }
153+ if (mlProfileModelResponse .getTargetWorkerNodes () == null || mlProfileModelResponse .getWorkerNodes () == null ) {
154+ mlProfileModelResponse .setTargetWorkerNodes (entry .getValue ().getTargetWorkerNodes ());
155+ mlProfileModelResponse .setWorkerNodes (entry .getValue ().getWorkerNodes ());
156+ }
157+ // Create a new object and remove targetWorkerNodes and workerNodes.
158+ MLModelProfile modelProfile = new MLModelProfile (
159+ entry .getValue ().getModelState (),
160+ entry .getValue ().getPredictor (),
161+ null ,
162+ null ,
163+ entry .getValue ().getModelInferenceStats (),
164+ entry .getValue ().getPredictRequestStats ()
165+ );
166+ mlProfileModelResponse .getMlModelProfileMap ().putAll (ImmutableMap .of (nodeId , modelProfile ));
167+ }
168+
169+ for (Map .Entry <String , MLTask > entry : taskProfileMap .entrySet ()) {
170+ String modelId = entry .getValue ().getModelId ();
171+ MLProfileModelResponse mlProfileModelResponse = modelCentricMap .get (modelId );
172+ if (mlProfileModelResponse == null ) {
173+ mlProfileModelResponse = new MLProfileModelResponse ();
174+ modelCentricMap .put (modelId , mlProfileModelResponse );
175+ }
176+ mlProfileModelResponse .getMlTaskMap ().putAll (ImmutableMap .of (entry .getKey (), entry .getValue ()));
177+ }
178+ }
179+ return modelCentricMap ;
180+ }
181+
108182 MLProfileInput createMLProfileInputFromRequestParams (RestRequest request ) {
109183 MLProfileInput mlProfileInput = new MLProfileInput ();
110184 Optional <String []> modelIds = splitCommaSeparatedParam (request , PARAMETER_MODEL_ID );
0 commit comments