Skip to content

Commit f28f97e

Browse files
Google AI Edgecopybara-github
authored andcommitted
Refactor ModelGraphVisualizer config to use signals and add highlight toggle.
PiperOrigin-RevId: 880804502
1 parent 93edbca commit f28f97e

File tree

6 files changed

+177
-22
lines changed

6 files changed

+177
-22
lines changed

src/ui/src/components/visualizer/common/consts.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ import {IS_EXTERNAL} from '../../../common/flags';
2121
/** The height of the node label. */
2222
export const DEFAULT_NODE_LABEL_HEIGHT = 11;
2323

24+
/** The padding of the label. */
25+
export const LABEL_PADDING = 24;
26+
2427
/**
2528
* The padding between the label and the value in node's attrs table.
2629
*/

src/ui/src/components/visualizer/common/utils.ts

Lines changed: 95 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1038,14 +1038,99 @@ export function splitLabel(label: string): string[] {
10381038
.filter((line) => line !== '');
10391039
}
10401040

1041+
/** Wraps the label based on the given max width. */
1042+
export function wrapLabel(
1043+
label: string,
1044+
maxWidth: number,
1045+
fontSize: number,
1046+
bold = false,
1047+
): string[] {
1048+
const sections = label.split('\n');
1049+
const allLines: string[] = [];
1050+
1051+
for (const section of sections) {
1052+
if (section === '') {
1053+
allLines.push('');
1054+
continue;
1055+
}
1056+
1057+
const words = section.split(' ');
1058+
const lines: string[] = [];
1059+
let currentLine = words[0];
1060+
1061+
for (let i = 1; i < words.length; i++) {
1062+
const word = words[i];
1063+
const width = getLabelWidth(currentLine + ' ' + word, fontSize, bold);
1064+
if (width < maxWidth) {
1065+
currentLine += ' ' + word;
1066+
} else {
1067+
lines.push(currentLine);
1068+
currentLine = word;
1069+
}
1070+
}
1071+
lines.push(currentLine);
1072+
1073+
// Post-process to split any lines that are still too long (single words).
1074+
const finalLines: string[] = [];
1075+
for (const line of lines) {
1076+
if (getLabelWidth(line, fontSize, bold) <= maxWidth) {
1077+
finalLines.push(line);
1078+
continue;
1079+
}
1080+
// Split long word.
1081+
let curLine = line;
1082+
while (getLabelWidth(curLine, fontSize, bold) > maxWidth) {
1083+
// Find split point.
1084+
let low = 0;
1085+
let high = curLine.length;
1086+
let splitIndex = 0;
1087+
while (low <= high) {
1088+
const mid = Math.floor((low + high) / 2);
1089+
const sub = curLine.substring(0, mid);
1090+
if (getLabelWidth(sub, fontSize, bold) <= maxWidth) {
1091+
splitIndex = mid;
1092+
low = mid + 1;
1093+
} else {
1094+
high = mid - 1;
1095+
}
1096+
}
1097+
if (splitIndex === 0) {
1098+
// If even one char is too wide, just take one char (shouldn't happen with reasonable width).
1099+
splitIndex = 1;
1100+
}
1101+
finalLines.push(curLine.substring(0, splitIndex));
1102+
curLine = curLine.substring(splitIndex);
1103+
}
1104+
if (curLine.length > 0) {
1105+
finalLines.push(curLine);
1106+
}
1107+
}
1108+
allLines.push(...finalLines);
1109+
}
1110+
1111+
return allLines;
1112+
}
1113+
10411114
/** Get the extra height for multi-line label. */
10421115
export function getMultiLineLabelExtraHeight(
10431116
node: ModelNode,
10441117
config?: VisualizerConfig,
10451118
): number {
1046-
return (
1047-
(splitLabel(node.label).length - 1) * getNodeLabelLineHeight(node, config)
1048-
);
1119+
let lineCount = 0;
1120+
if (config?.nodeLabelWidth) {
1121+
const lines = wrapLabel(
1122+
node.label,
1123+
config.nodeLabelWidth,
1124+
getNodeLabelHeight(node, config),
1125+
!isOpNode(node),
1126+
);
1127+
if (lines.length >= 1) {
1128+
return (lines.length - 1) * getNodeLabelLineHeight(node, config);
1129+
}
1130+
} else {
1131+
lineCount = splitLabel(node.label).length;
1132+
}
1133+
return (lineCount - 1) * getNodeLabelLineHeight(node, config);
10491134
}
10501135

10511136
/**
@@ -1329,9 +1414,14 @@ export function getLayoutMarginTop(
13291414
config?: VisualizerConfig,
13301415
): number {
13311416
const nodeLabelHeight = getNodeLabelHeight(node, config);
1417+
let extraHeight = 0;
1418+
if (config?.nodeLabelWidth) {
1419+
extraHeight = getMultiLineLabelExtraHeight(node, config);
1420+
}
1421+
13321422
if (nodeLabelHeight === 11) {
1333-
return 36;
1423+
return 36 + extraHeight;
13341424
}
13351425
const nodeLabelYPadding = getNodeLabelYPadding(node, config);
1336-
return nodeLabelYPadding + nodeLabelHeight + 16;
1426+
return nodeLabelYPadding + nodeLabelHeight + extraHeight + 16;
13371427
}

src/ui/src/components/visualizer/common/visualizer_config.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ export declare interface VisualizerConfig {
7272
/** The maximum number of child nodes under a layer node. Default: 400. */
7373
artificialLayerNodeCountThreshold?: number;
7474

75+
/**
76+
* The maximum width for node labels.
77+
*
78+
* If set, node labels will be wrapped to this width.
79+
* If unset (default), the node labels will have a default maximum width.
80+
*/
81+
nodeLabelWidth?: number;
82+
7583
/** The font size of the edge label. Default: 7.5. */
7684
edgeLabelFontSize?: number;
7785

src/ui/src/components/visualizer/webgl_renderer.ts

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ import {
103103
processNodeStylerRules,
104104
splitLabel,
105105
splitNamespace,
106+
wrapLabel,
106107
} from './common/utils';
107108
import {
108109
ExpandOrCollapseGroupNodeRequest,
@@ -2552,12 +2553,19 @@ export class WebglRenderer implements OnInit, OnChanges, OnDestroy {
25522553

25532554
// Expand icon.
25542555
const iconZ = y + this.getNodeLabelRelativeY(node) + 18.5;
2555-
const leftIconX = node.expanded
2556-
? labelLeft - 13
2557-
: (x + labelLeft + 1) / 2 + 1;
2558-
const rightIconX = node.expanded
2559-
? labelRight + 12
2560-
: (x + width + labelRight - 1) / 2 - 1;
2556+
let leftIconX = 0;
2557+
let rightIconX = 0;
2558+
if (this.appService.config()?.nodeLabelWidth) {
2559+
leftIconX = x + 12;
2560+
rightIconX = x + width - 12;
2561+
} else {
2562+
leftIconX = node.expanded
2563+
? labelLeft - 13
2564+
: (x + labelLeft + 1) / 2 + 1;
2565+
rightIconX = node.expanded
2566+
? labelRight + 12
2567+
: (x + width + labelRight - 1) / 2 - 1;
2568+
}
25612569
groupNodeIcons.push({
25622570
id: node.id,
25632571
nodeId: node.id,
@@ -2784,7 +2792,19 @@ export class WebglRenderer implements OnInit, OnChanges, OnDestroy {
27842792

27852793
// Font size.
27862794
const labelHeight = getNodeLabelHeight(node, this.appService.config());
2787-
const lines = splitLabel(this.getNodeLabel(node));
2795+
let lines: string[] = [];
2796+
const config = this.appService.config();
2797+
if (config?.nodeLabelWidth) {
2798+
lines = wrapLabel(
2799+
this.getNodeLabel(node),
2800+
config.nodeLabelWidth,
2801+
labelHeight,
2802+
!isOpNode(node),
2803+
);
2804+
} else {
2805+
lines = splitLabel(this.getNodeLabel(node));
2806+
}
2807+
27882808
for (let i = 0; i < lines.length; i++) {
27892809
const curLineLabel = lines[i];
27902810
labels.push({

src/ui/src/components/visualizer/webgl_renderer_search_results_service.ts

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ import {
2727
getNodeLabelHeight,
2828
getNodeLabelLineHeight,
2929
isGroupNode,
30+
isOpNode,
3031
splitLabel,
32+
wrapLabel,
3133
} from './common/utils';
3234
import {
3335
ColorVariable,
@@ -177,8 +179,19 @@ export class WebglRendererSearchResultsService {
177179
let y = 0;
178180
let height = 0;
179181
let width = 0;
180-
const lines = splitLabel(node.label);
181-
if (lines.length === 1) {
182+
const config = this.webglRenderer.appService.config();
183+
let lines: string[] = [];
184+
if (config?.nodeLabelWidth) {
185+
lines = wrapLabel(
186+
node.label,
187+
config.nodeLabelWidth,
188+
nodeLabelHeight,
189+
!isOpNode(node),
190+
);
191+
} else {
192+
lines = splitLabel(node.label);
193+
}
194+
if (lines.length === 1 && !config?.nodeLabelWidth) {
182195
const labelSizes = this.webglRenderer.texts.getLabelSizes(
183196
node.label,
184197
isGroupNode(node) ? FontWeight.BOLD : FontWeight.MEDIUM,
@@ -191,8 +204,19 @@ export class WebglRendererSearchResultsService {
191204
this.webglRenderer.getNodeLabelRelativeY(node) -
192205
2 * scale;
193206
} else {
194-
const {minX, maxX} = this.webglRenderer.getNodeLabelSizes(node);
195-
width = (maxX - minX) * scale + 4;
207+
let maxLineWidth = 0;
208+
for (const line of lines) {
209+
const labelSizes = this.webglRenderer.texts.getLabelSizes(
210+
line,
211+
isGroupNode(node) ? FontWeight.BOLD : FontWeight.MEDIUM,
212+
nodeLabelHeight,
213+
).sizes;
214+
maxLineWidth = Math.max(
215+
maxLineWidth,
216+
labelSizes.maxX - labelSizes.minX,
217+
);
218+
}
219+
width = maxLineWidth * scale + 4;
196220
height =
197221
lines.length *
198222
getNodeLabelLineHeight(

src/ui/src/components/visualizer/worker/graph_layout.ts

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import {
2222
EXPANDED_NODE_DATA_PROVIDER_SUMMARY_ROW_HEIGHT,
2323
EXPANDED_NODE_DATA_PROVIDER_SUMMARY_TOP_PADDING,
2424
EXPANDED_NODE_DATA_PROVIDER_SYUMMARY_FONT_SIZE,
25+
LABEL_PADDING,
2526
LAYOUT_MARGIN_X,
2627
MAX_IO_ROWS_IN_ATTRS_TABLE,
2728
NODE_ATTRS_TABLE_FONT_SIZE_TO_HEIGHT_RATIO,
@@ -69,6 +70,7 @@ import {
6970
isGroupNode,
7071
isOpNode,
7172
splitLabel,
73+
wrapLabel,
7274
} from '../common/utils';
7375
import {VisualizerConfig} from '../common/visualizer_config';
7476

@@ -83,8 +85,6 @@ export const LAYOUT_MARGIN_BOTTOM = 16;
8385
/** Node width for test cases. */
8486
export const NODE_WIDTH_FOR_TEST = 50;
8587

86-
const LABEL_PADDING = 24;
87-
8888
const MIN_NODE_WIDTH = 80;
8989

9090
const ATTRS_TABLE_MARGIN_X = 8;
@@ -417,8 +417,8 @@ export function getNodeWidth(
417417
testMode = false,
418418
config?: VisualizerConfig,
419419
) {
420-
// Always return 32 in test mode.
421-
if (testMode) {
420+
// Always return 32 in test mode, unless nodeLabelWidth is set.
421+
if (testMode && !config?.nodeLabelWidth) {
422422
return NODE_WIDTH_FOR_TEST;
423423
}
424424

@@ -428,7 +428,17 @@ export function getNodeWidth(
428428
fontSize * NODE_ATTRS_TABLE_VALUE_MAX_CHAR_COUNT;
429429

430430
const label = node.label;
431-
const lines = splitLabel(label);
431+
let lines: string[] = [];
432+
if (config?.nodeLabelWidth) {
433+
lines = wrapLabel(
434+
label,
435+
config.nodeLabelWidth,
436+
getNodeLabelHeight(node, config),
437+
isGroupNode(node),
438+
);
439+
} else {
440+
lines = splitLabel(label);
441+
}
432442
let labelWidth = 0;
433443
for (const line of lines) {
434444
labelWidth = Math.max(
@@ -601,7 +611,7 @@ export function getNodeHeight(
601611
forceRecalculate = false,
602612
config?: VisualizerConfig,
603613
) {
604-
if (testMode) {
614+
if (testMode && !config?.nodeLabelWidth) {
605615
return NODE_HEIGHT_FOR_TEST;
606616
}
607617

0 commit comments

Comments
 (0)