Skip to content

Commit 8a47070

Browse files
Sean Purser-Haskell authored and copybara-github committed
Change how new FSM generates next state values.
Goals:
- Reduce build times, fixing many timeouts.
- Reduce the size of generated IR by orders of magnitude when there are many states, as with deeply nested loops.
- Improve QoR (quality of results).

Changes:
1. Caching of phi conditions in GeneratePhiCondition() reduces the size of the IR considerably.
2. The caching from item 1 helps to enable the consolidation of entries added to extra_next_state_values, which ultimately feed into a priority select for the next value of the state element. The XLS back-end does much better with this reduced form, with fewer bits in the priority select, fed by ORs over multiple conditions.

PiperOrigin-RevId: 863297433
1 parent 5cbee79 commit 8a47070

File tree

3 files changed

+131
-27
lines changed

3 files changed

+131
-27
lines changed

xls/contrib/xlscc/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ cc_library(
202202
"@com_google_absl//absl/status:statusor",
203203
"@com_google_absl//absl/strings",
204204
"@com_google_absl//absl/strings:str_format",
205+
"@com_google_absl//absl/types:span",
205206
"@llvm-project//clang:ast",
206207
],
207208
)

xls/contrib/xlscc/generate_fsm.cc

Lines changed: 117 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <limits>
2020
#include <optional>
2121
#include <string>
22+
#include <tuple>
2223
#include <utility>
2324
#include <vector>
2425

@@ -32,6 +33,7 @@
3233
#include "absl/status/statusor.h"
3334
#include "absl/strings/str_format.h"
3435
#include "absl/strings/str_join.h"
36+
#include "absl/types/span.h"
3537
#include "clang/include/clang/AST/Decl.h"
3638
#include "xls/common/math_util.h"
3739
#include "xls/common/status/status_macros.h"
@@ -635,6 +637,8 @@ NewFSMGenerator::GenerateNewFSMInvocation(
635637
XLS_ASSIGN_OR_RETURN(layout,
636638
LayoutNewFSM(func, state_element_for_static, body_loc));
637639

640+
absl::flat_hash_map<PhiConditionCacheKey, TrackedBValue> generated_conditions;
641+
638642
const int64_t num_slice_index_bits =
639643
xls::CeilOfLog2(1 + xls_func->slices.size());
640644

@@ -707,7 +711,7 @@ NewFSMGenerator::GenerateNewFSMInvocation(
707711
XLS_ASSIGN_OR_RETURN(
708712
phi_elements_by_param_node_id,
709713
GeneratePhiConditions(layout, state_element_by_jump_slice_index, pb,
710-
body_loc));
714+
body_loc, generated_conditions));
711715

712716
// The value from the current activation's perspective,
713717
// either outputted from invoke or state element.
@@ -763,6 +767,30 @@ NewFSMGenerator::GenerateNewFSMInvocation(
763767
TrackedBValue after_activation_transition =
764768
pb.Literal(xls::UBits(0, 1), body_loc);
765769

770+
// Sort by Node ID and StateElement name for determinism.
771+
struct StateElementAndNodeLessThan {
772+
bool operator()(const std::tuple<xls::StateElement*, xls::Node*>& a,
773+
const std::tuple<xls::StateElement*, xls::Node*>& b) const {
774+
const auto& [a_elem, a_node] = a;
775+
const auto& [b_elem, b_node] = b;
776+
if (a_elem->name() != b_elem->name()) {
777+
return a_elem->name() < b_elem->name();
778+
}
779+
return a_node->id() < b_node->id();
780+
}
781+
};
782+
783+
struct NodeIdLessThan {
784+
bool operator()(const xls::Node* a, const xls::Node* b) const {
785+
return a->id() < b->id();
786+
}
787+
};
788+
789+
absl::btree_map<std::tuple<xls::StateElement*, xls::Node*>,
790+
absl::btree_set<xls::Node*, NodeIdLessThan>,
791+
StateElementAndNodeLessThan>
792+
next_value_conditions_by_state_element_and_value;
793+
766794
for (int64_t slice_index = 0; slice_index < func.slices.size();
767795
++slice_index) {
768796
const bool is_last_slice = (slice_index == func.slices.size() - 1);
@@ -1026,44 +1054,65 @@ NewFSMGenerator::GenerateNewFSMInvocation(
10261054
if (state.slice_index != slice_index) {
10271055
continue;
10281056
}
1029-
10301057
absl::btree_set<int64_t> jumped_from_slice_indices_this_state;
10311058
for (const JumpInfo& jump_info : state.jumped_from_slice_indices) {
10321059
jumped_from_slice_indices_this_state.insert(jump_info.from_slice);
10331060
}
10341061

10351062
XLS_ASSIGN_OR_RETURN(
10361063
TrackedBValue state_active_condition,
1037-
GeneratePhiCondition(from_jump_slice_indices,
1038-
jumped_from_slice_indices_this_state,
1039-
state_element_by_jump_slice_index, pb,
1040-
state.slice_index, body_loc));
1064+
GeneratePhiCondition(
1065+
from_jump_slice_indices, jumped_from_slice_indices_this_state,
1066+
state_element_by_jump_slice_index, pb, state.slice_index,
1067+
body_loc, generated_conditions));
10411068

10421069
TrackedBValue next_value_condition =
10431070
pb.And(state_active_condition, jump_condition, body_loc,
10441071
/*name=*/GetIRStateName(state));
10451072

10461073
for (const ContinuationValue* continuation_out : state.values_to_save) {
10471074
// Generate next values for state elements
1048-
NextStateValue next_value = {
1049-
.priority = 0,
1050-
.value = value_by_continuation_value.at(continuation_out),
1051-
.condition = next_value_condition,
1052-
};
1053-
10541075
xls::StateElement* state_elem =
10551076
state_element_by_continuation_value.at(continuation_out)
10561077
.node()
10571078
->As<xls::StateRead>()
10581079
->state_element();
10591080

1081+
std::tuple<xls::StateElement*, xls::Node*> key = {
1082+
state_elem,
1083+
value_by_continuation_value.at(continuation_out).node()};
1084+
10601085
// Generate next values
1061-
extra_next_state_values.insert({state_elem, next_value});
1086+
next_value_conditions_by_state_element_and_value[key].insert(
1087+
next_value_condition.node());
10621088
}
10631089
}
10641090
}
10651091
}
10661092

1093+
for (auto& [key, or_nodes] :
1094+
next_value_conditions_by_state_element_and_value) {
1095+
xls::StateElement* state_elem = std::get<0>(key);
1096+
xls::Node* next_value_node = std::get<1>(key);
1097+
std::vector<NATIVE_BVAL> or_bvals;
1098+
for (xls::Node* or_node : or_nodes) {
1099+
or_bvals.push_back(NATIVE_BVAL(or_node, &pb));
1100+
}
1101+
1102+
TrackedBValue or_bval =
1103+
pb.Or(absl::MakeSpan(or_bvals), body_loc,
1104+
/*name=*/
1105+
absl::StrFormat("%s_v_%s_or_bval", state_elem->name(),
1106+
next_value_node->GetName()));
1107+
1108+
NextStateValue next_value = {
1109+
.priority = 0,
1110+
.value = TrackedBValue(next_value_node, &pb),
1111+
.condition = or_bval,
1112+
};
1113+
extra_next_state_values.insert({state_elem, next_value});
1114+
}
1115+
10671116
// Set next slice index
10681117
const TrackedBValue finished_iteration =
10691118
pb.Not(after_activation_transition, body_loc,
@@ -1098,8 +1147,16 @@ absl::StatusOr<TrackedBValue> NewFSMGenerator::GeneratePhiCondition(
10981147
const absl::btree_set<int64_t>& jumped_from_slice_indices_this_state,
10991148
const absl::flat_hash_map<int64_t, TrackedBValue>&
11001149
state_element_by_jump_slice_index,
1101-
xls::ProcBuilder& pb, int64_t slice_index,
1102-
const xls::SourceInfo& body_loc) {
1150+
xls::ProcBuilder& pb, int64_t slice_index, const xls::SourceInfo& body_loc,
1151+
absl::flat_hash_map<PhiConditionCacheKey, TrackedBValue>&
1152+
phi_condition_cache) {
1153+
PhiConditionCacheKey key = {from_jump_slice_indices,
1154+
jumped_from_slice_indices_this_state};
1155+
1156+
if (phi_condition_cache.contains(key)) {
1157+
return phi_condition_cache.at(key);
1158+
}
1159+
11031160
TrackedBValue condition = pb.Literal(xls::UBits(1, 1), body_loc);
11041161

11051162
// Include all jump slices in each condition
@@ -1110,6 +1167,7 @@ absl::StatusOr<TrackedBValue> NewFSMGenerator::GeneratePhiCondition(
11101167
jumped_from_slice_indices_this_state.contains(from_jump_slice_index)
11111168
? 1
11121169
: 0;
1170+
11131171
TrackedBValue condition_part =
11141172
pb.Eq(jump_state_element,
11151173
pb.Literal(xls::UBits(active_value, 1), body_loc,
@@ -1125,6 +1183,7 @@ absl::StatusOr<TrackedBValue> NewFSMGenerator::GeneratePhiCondition(
11251183
absl::StrJoin(jumped_from_slice_indices_this_state, "_")));
11261184
}
11271185

1186+
phi_condition_cache[key] = condition;
11281187
return condition;
11291188
}
11301189

@@ -1134,7 +1193,9 @@ NewFSMGenerator::GeneratePhiConditions(
11341193
const NewFSMLayout& layout,
11351194
const absl::flat_hash_map<int64_t, TrackedBValue>&
11361195
state_element_by_jump_slice_index,
1137-
xls::ProcBuilder& pb, const xls::SourceInfo& body_loc) {
1196+
xls::ProcBuilder& pb, const xls::SourceInfo& body_loc,
1197+
absl::flat_hash_map<PhiConditionCacheKey, TrackedBValue>&
1198+
phi_condition_cache) {
11381199
absl::flat_hash_map<int64_t, std::vector<PhiElement>>
11391200
phi_elements_by_param_node_id;
11401201

@@ -1171,10 +1232,10 @@ NewFSMGenerator::GeneratePhiConditions(
11711232

11721233
XLS_ASSIGN_OR_RETURN(
11731234
TrackedBValue condition,
1174-
GeneratePhiCondition(from_jump_slice_indices,
1175-
jumped_from_slice_indices_this_state,
1176-
state_element_by_jump_slice_index, pb,
1177-
state->slice_index, body_loc));
1235+
GeneratePhiCondition(
1236+
from_jump_slice_indices, jumped_from_slice_indices_this_state,
1237+
state_element_by_jump_slice_index, pb, state->slice_index,
1238+
body_loc, phi_condition_cache));
11781239

11791240
PhiElement& phi_element = phi_elements.emplace_back();
11801241
phi_element.value = state->current_inputs_by_input_param.at(param);
@@ -1204,13 +1265,47 @@ NewFSMGenerator::GenerateInputValueInContext(
12041265
std::vector<TrackedBValue> phi_conditions;
12051266
std::vector<TrackedBValue> phi_values;
12061267

1268+
// Sort by Node ID for determinism.
1269+
struct NodeIdLessThan {
1270+
bool operator()(const xls::Node* a, const xls::Node* b) const {
1271+
return a->id() < b->id();
1272+
}
1273+
};
1274+
struct BValueIdLessThan {
1275+
bool operator()(const TrackedBValue& a, const TrackedBValue& b) const {
1276+
return a.node()->id() < b.node()->id();
1277+
}
1278+
};
1279+
absl::btree_map<xls::Node*, absl::btree_set<TrackedBValue, BValueIdLessThan>,
1280+
NodeIdLessThan>
1281+
conditions_by_value_node;
1282+
12071283
phi_conditions.reserve(phi_elements.size());
12081284
phi_values.reserve(phi_elements.size());
1285+
12091286
for (const PhiElement& phi_element : phi_elements) {
1210-
phi_conditions.push_back(phi_element.condition);
12111287
XLSCC_CHECK(value_by_continuation_value.contains(phi_element.value),
12121288
phi_element.value->output_node->loc());
1213-
phi_values.push_back(value_by_continuation_value.at(phi_element.value));
1289+
1290+
xls::Node* value_node =
1291+
value_by_continuation_value.at(phi_element.value).node();
1292+
conditions_by_value_node[value_node].insert(phi_element.condition);
1293+
}
1294+
1295+
for (auto& [value_node, or_nodes] : conditions_by_value_node) {
1296+
std::vector<NATIVE_BVAL> or_bvals;
1297+
or_bvals.reserve(or_nodes.size());
1298+
for (const TrackedBValue& or_node : or_nodes) {
1299+
or_bvals.push_back(or_node);
1300+
}
1301+
1302+
TrackedBValue or_bval =
1303+
pb.Or(absl::MakeSpan(or_bvals), body_loc,
1304+
/*name=*/
1305+
absl::StrFormat("%s_v_%s_or_bval", param->name(),
1306+
value_node->GetName()));
1307+
phi_conditions.push_back(or_bval);
1308+
phi_values.push_back(TrackedBValue(value_node, &pb));
12141309
}
12151310

12161311
std::reverse(phi_conditions.begin(), phi_conditions.end());

xls/contrib/xlscc/generate_fsm.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,19 +145,27 @@ class NewFSMGenerator : public GeneratorBase {
145145
const ContinuationValue* value;
146146
};
147147

148+
typedef std::tuple<absl::btree_set<int64_t>, absl::btree_set<int64_t>>
149+
PhiConditionCacheKey;
150+
148151
absl::StatusOr<absl::flat_hash_map<int64_t, std::vector<PhiElement>>>
149-
GeneratePhiConditions(const NewFSMLayout& layout,
150-
const absl::flat_hash_map<int64_t, TrackedBValue>&
151-
state_element_by_jump_slice_index,
152-
xls::ProcBuilder& pb, const xls::SourceInfo& body_loc);
152+
GeneratePhiConditions(
153+
const NewFSMLayout& layout,
154+
const absl::flat_hash_map<int64_t, TrackedBValue>&
155+
state_element_by_jump_slice_index,
156+
xls::ProcBuilder& pb, const xls::SourceInfo& body_loc,
157+
absl::flat_hash_map<PhiConditionCacheKey, TrackedBValue>&
158+
phi_condition_cache);
153159

154160
absl::StatusOr<TrackedBValue> GeneratePhiCondition(
155161
const absl::btree_set<int64_t>& from_jump_slice_indices,
156162
const absl::btree_set<int64_t>& jumped_from_slice_indices_this_state,
157163
const absl::flat_hash_map<int64_t, TrackedBValue>&
158164
state_element_by_jump_slice_index,
159165
xls::ProcBuilder& pb, int64_t slice_index,
160-
const xls::SourceInfo& body_loc);
166+
const xls::SourceInfo& body_loc,
167+
absl::flat_hash_map<PhiConditionCacheKey, TrackedBValue>&
168+
phi_condition_cache);
161169

162170
absl::StatusOr<std::optional<TrackedBValue>> GenerateInputValueInContext(
163171
const xls::Param* param,

0 commit comments

Comments
 (0)