1919#include < limits>
2020#include < optional>
2121#include < string>
22+ #include < tuple>
2223#include < utility>
2324#include < vector>
2425
3233#include " absl/status/statusor.h"
3334#include " absl/strings/str_format.h"
3435#include " absl/strings/str_join.h"
36+ #include " absl/types/span.h"
3537#include " clang/include/clang/AST/Decl.h"
3638#include " xls/common/math_util.h"
3739#include " xls/common/status/status_macros.h"
@@ -635,6 +637,8 @@ NewFSMGenerator::GenerateNewFSMInvocation(
635637 XLS_ASSIGN_OR_RETURN (layout,
636638 LayoutNewFSM (func, state_element_for_static, body_loc));
637639
640+ absl::flat_hash_map<PhiConditionCacheKey, TrackedBValue> generated_conditions;
641+
638642 const int64_t num_slice_index_bits =
639643 xls::CeilOfLog2 (1 + xls_func->slices .size ());
640644
@@ -707,7 +711,7 @@ NewFSMGenerator::GenerateNewFSMInvocation(
707711 XLS_ASSIGN_OR_RETURN (
708712 phi_elements_by_param_node_id,
709713 GeneratePhiConditions (layout, state_element_by_jump_slice_index, pb,
710- body_loc));
714+ body_loc, generated_conditions ));
711715
712716 // The value from the current activation's perspective,
713717 // either outputted from invoke or state element.
@@ -763,6 +767,30 @@ NewFSMGenerator::GenerateNewFSMInvocation(
763767 TrackedBValue after_activation_transition =
764768 pb.Literal (xls::UBits (0 , 1 ), body_loc);
765769
770+ // Sort by Node ID and StateElement name for determinism.
771+ struct StateElementAndNodeLessThan {
772+ bool operator ()(const std::tuple<xls::StateElement*, xls::Node*>& a,
773+ const std::tuple<xls::StateElement*, xls::Node*>& b) const {
774+ const auto & [a_elem, a_node] = a;
775+ const auto & [b_elem, b_node] = b;
776+ if (a_elem->name () != b_elem->name ()) {
777+ return a_elem->name () < b_elem->name ();
778+ }
779+ return a_node->id () < b_node->id ();
780+ }
781+ };
782+
783+ struct NodeIdLessThan {
784+ bool operator ()(const xls::Node* a, const xls::Node* b) const {
785+ return a->id () < b->id ();
786+ }
787+ };
788+
789+ absl::btree_map<std::tuple<xls::StateElement*, xls::Node*>,
790+ absl::btree_set<xls::Node*, NodeIdLessThan>,
791+ StateElementAndNodeLessThan>
792+ next_value_conditions_by_state_element_and_value;
793+
766794 for (int64_t slice_index = 0 ; slice_index < func.slices .size ();
767795 ++slice_index) {
768796 const bool is_last_slice = (slice_index == func.slices .size () - 1 );
@@ -1026,44 +1054,65 @@ NewFSMGenerator::GenerateNewFSMInvocation(
10261054 if (state.slice_index != slice_index) {
10271055 continue ;
10281056 }
1029-
10301057 absl::btree_set<int64_t > jumped_from_slice_indices_this_state;
10311058 for (const JumpInfo& jump_info : state.jumped_from_slice_indices ) {
10321059 jumped_from_slice_indices_this_state.insert (jump_info.from_slice );
10331060 }
10341061
10351062 XLS_ASSIGN_OR_RETURN (
10361063 TrackedBValue state_active_condition,
1037- GeneratePhiCondition (from_jump_slice_indices,
1038- jumped_from_slice_indices_this_state,
1039- state_element_by_jump_slice_index, pb,
1040- state. slice_index , body_loc ));
1064+ GeneratePhiCondition (
1065+ from_jump_slice_indices, jumped_from_slice_indices_this_state,
1066+ state_element_by_jump_slice_index, pb, state. slice_index ,
1067+ body_loc, generated_conditions ));
10411068
10421069 TrackedBValue next_value_condition =
10431070 pb.And (state_active_condition, jump_condition, body_loc,
10441071 /* name=*/ GetIRStateName (state));
10451072
10461073 for (const ContinuationValue* continuation_out : state.values_to_save ) {
10471074 // Generate next values for state elements
1048- NextStateValue next_value = {
1049- .priority = 0 ,
1050- .value = value_by_continuation_value.at (continuation_out),
1051- .condition = next_value_condition,
1052- };
1053-
10541075 xls::StateElement* state_elem =
10551076 state_element_by_continuation_value.at (continuation_out)
10561077 .node ()
10571078 ->As <xls::StateRead>()
10581079 ->state_element ();
10591080
1081+ std::tuple<xls::StateElement*, xls::Node*> key = {
1082+ state_elem,
1083+ value_by_continuation_value.at (continuation_out).node ()};
1084+
10601085 // Generate next values
1061- extra_next_state_values.insert ({state_elem, next_value});
1086+ next_value_conditions_by_state_element_and_value[key].insert (
1087+ next_value_condition.node ());
10621088 }
10631089 }
10641090 }
10651091 }
10661092
1093+ for (auto & [key, or_nodes] :
1094+ next_value_conditions_by_state_element_and_value) {
1095+ xls::StateElement* state_elem = std::get<0 >(key);
1096+ xls::Node* next_value_node = std::get<1 >(key);
1097+ std::vector<NATIVE_BVAL> or_bvals;
1098+ for (xls::Node* or_node : or_nodes) {
1099+ or_bvals.push_back (NATIVE_BVAL (or_node, &pb));
1100+ }
1101+
1102+ TrackedBValue or_bval =
1103+ pb.Or (absl::MakeSpan (or_bvals), body_loc,
1104+ /* name=*/
1105+ absl::StrFormat (" %s_v_%s_or_bval" , state_elem->name (),
1106+ next_value_node->GetName ()));
1107+
1108+ NextStateValue next_value = {
1109+ .priority = 0 ,
1110+ .value = TrackedBValue (next_value_node, &pb),
1111+ .condition = or_bval,
1112+ };
1113+ extra_next_state_values.insert ({state_elem, next_value});
1114+ }
1115+
10671116 // Set next slice index
10681117 const TrackedBValue finished_iteration =
10691118 pb.Not (after_activation_transition, body_loc,
@@ -1098,8 +1147,16 @@ absl::StatusOr<TrackedBValue> NewFSMGenerator::GeneratePhiCondition(
10981147 const absl::btree_set<int64_t >& jumped_from_slice_indices_this_state,
10991148 const absl::flat_hash_map<int64_t , TrackedBValue>&
11001149 state_element_by_jump_slice_index,
1101- xls::ProcBuilder& pb, int64_t slice_index,
1102- const xls::SourceInfo& body_loc) {
1150+ xls::ProcBuilder& pb, int64_t slice_index, const xls::SourceInfo& body_loc,
1151+ absl::flat_hash_map<PhiConditionCacheKey, TrackedBValue>&
1152+ phi_condition_cache) {
1153+ PhiConditionCacheKey key = {from_jump_slice_indices,
1154+ jumped_from_slice_indices_this_state};
1155+
1156+ if (phi_condition_cache.contains (key)) {
1157+ return phi_condition_cache.at (key);
1158+ }
1159+
11031160 TrackedBValue condition = pb.Literal (xls::UBits (1 , 1 ), body_loc);
11041161
11051162 // Include all jump slices in each condition
@@ -1110,6 +1167,7 @@ absl::StatusOr<TrackedBValue> NewFSMGenerator::GeneratePhiCondition(
11101167 jumped_from_slice_indices_this_state.contains (from_jump_slice_index)
11111168 ? 1
11121169 : 0 ;
1170+
11131171 TrackedBValue condition_part =
11141172 pb.Eq (jump_state_element,
11151173 pb.Literal (xls::UBits (active_value, 1 ), body_loc,
@@ -1125,6 +1183,7 @@ absl::StatusOr<TrackedBValue> NewFSMGenerator::GeneratePhiCondition(
11251183 absl::StrJoin (jumped_from_slice_indices_this_state, " _" )));
11261184 }
11271185
1186+ phi_condition_cache[key] = condition;
11281187 return condition;
11291188}
11301189
@@ -1134,7 +1193,9 @@ NewFSMGenerator::GeneratePhiConditions(
11341193 const NewFSMLayout& layout,
11351194 const absl::flat_hash_map<int64_t , TrackedBValue>&
11361195 state_element_by_jump_slice_index,
1137- xls::ProcBuilder& pb, const xls::SourceInfo& body_loc) {
1196+ xls::ProcBuilder& pb, const xls::SourceInfo& body_loc,
1197+ absl::flat_hash_map<PhiConditionCacheKey, TrackedBValue>&
1198+ phi_condition_cache) {
11381199 absl::flat_hash_map<int64_t , std::vector<PhiElement>>
11391200 phi_elements_by_param_node_id;
11401201
@@ -1171,10 +1232,10 @@ NewFSMGenerator::GeneratePhiConditions(
11711232
11721233 XLS_ASSIGN_OR_RETURN (
11731234 TrackedBValue condition,
1174- GeneratePhiCondition (from_jump_slice_indices,
1175- jumped_from_slice_indices_this_state,
1176- state_element_by_jump_slice_index, pb,
1177- state-> slice_index , body_loc ));
1235+ GeneratePhiCondition (
1236+ from_jump_slice_indices, jumped_from_slice_indices_this_state,
1237+ state_element_by_jump_slice_index, pb, state-> slice_index ,
1238+ body_loc, phi_condition_cache ));
11781239
11791240 PhiElement& phi_element = phi_elements.emplace_back ();
11801241 phi_element.value = state->current_inputs_by_input_param .at (param);
@@ -1204,13 +1265,47 @@ NewFSMGenerator::GenerateInputValueInContext(
12041265 std::vector<TrackedBValue> phi_conditions;
12051266 std::vector<TrackedBValue> phi_values;
12061267
1268+ // Sort by Node ID for determinism.
1269+ struct NodeIdLessThan {
1270+ bool operator ()(const xls::Node* a, const xls::Node* b) const {
1271+ return a->id () < b->id ();
1272+ }
1273+ };
1274+ struct BValueIdLessThan {
1275+ bool operator ()(const TrackedBValue& a, const TrackedBValue& b) const {
1276+ return a.node ()->id () < b.node ()->id ();
1277+ }
1278+ };
1279+ absl::btree_map<xls::Node*, absl::btree_set<TrackedBValue, BValueIdLessThan>,
1280+ NodeIdLessThan>
1281+ conditions_by_value_node;
1282+
12071283 phi_conditions.reserve (phi_elements.size ());
12081284 phi_values.reserve (phi_elements.size ());
1285+
12091286 for (const PhiElement& phi_element : phi_elements) {
1210- phi_conditions.push_back (phi_element.condition );
12111287 XLSCC_CHECK (value_by_continuation_value.contains (phi_element.value ),
12121288 phi_element.value ->output_node ->loc ());
1213- phi_values.push_back (value_by_continuation_value.at (phi_element.value ));
1289+
1290+ xls::Node* value_node =
1291+ value_by_continuation_value.at (phi_element.value ).node ();
1292+ conditions_by_value_node[value_node].insert (phi_element.condition );
1293+ }
1294+
1295+ for (auto & [value_node, or_nodes] : conditions_by_value_node) {
1296+ std::vector<NATIVE_BVAL> or_bvals;
1297+ or_bvals.reserve (or_nodes.size ());
1298+ for (const TrackedBValue& or_node : or_nodes) {
1299+ or_bvals.push_back (or_node);
1300+ }
1301+
1302+ TrackedBValue or_bval =
1303+ pb.Or (absl::MakeSpan (or_bvals), body_loc,
1304+ /* name=*/
1305+ absl::StrFormat (" %s_v_%s_or_bval" , param->name (),
1306+ value_node->GetName ()));
1307+ phi_conditions.push_back (or_bval);
1308+ phi_values.push_back (TrackedBValue (value_node, &pb));
12141309 }
12151310
12161311 std::reverse (phi_conditions.begin (), phi_conditions.end ());
0 commit comments