Skip to content

Commit 8573a98

Browse files
authored
chore: improve sliver index-conversion (#2929)
This reduces unnecessary generics from sliver-index/pair-index conversion code, which allows streamlining some code and improving index checks.
1 parent 0ef70a6 commit 8573a98

File tree

9 files changed

+50
-50
lines changed

9 files changed

+50
-50
lines changed

crates/walrus-core/src/encoding/blob_encoding.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use super::{
2424
use crate::{
2525
SliverIndex,
2626
SliverPairIndex,
27+
SliverType,
2728
encoding::{ReedSolomonEncoder, config::EncodingFactory as _},
2829
ensure,
2930
merkle::{MerkleTree, Node, leaf_hash},
@@ -670,7 +671,7 @@ impl<'a> ExpandedMessageMatrix<'a> {
670671
.map(move |row| {
671672
row[SliverPairIndex::try_from(col_index)
672673
.expect("size has already been checked")
673-
.to_sliver_index::<Secondary>(self.config.n_shards())
674+
.to_sliver_index(self.config.n_shards(), SliverType::Secondary)
674675
.as_usize()]
675676
.as_ref()
676677
})

crates/walrus-core/src/encoding/slivers.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ use super::{
2929
use crate::{
3030
SliverIndex,
3131
SliverPairIndex,
32+
SliverType,
3233
encoding::{DecodeError, RequiredCount},
3334
ensure,
3435
inconsistency::{InconsistencyProof, SliverOrInconsistencyProof},
@@ -130,7 +131,7 @@ impl<T: EncodingAxis> SliverData<T> {
130131
.hashes()
131132
.get(
132133
self.index
133-
.to_pair_index::<T>(encoding_config.n_shards())
134+
.to_pair_index(encoding_config.n_shards(), T::sliver_type())
134135
.as_usize(),
135136
)
136137
.expect("hash must exist if all size checks have been performed");
@@ -198,7 +199,7 @@ impl<T: EncodingAxis> SliverData<T> {
198199

199200
let recovery_symbols = self.recovery_symbols(config)?;
200201
let target_sliver_index =
201-
target_pair_index.to_sliver_index::<T::OrthogonalAxis>(config.n_shards());
202+
target_pair_index.to_sliver_index(config.n_shards(), T::OrthogonalAxis::sliver_type());
202203

203204
Ok(recovery_symbols
204205
.decoding_symbol_at(target_sliver_index.as_usize(), self.index.into())
@@ -226,7 +227,7 @@ impl<T: EncodingAxis> SliverData<T> {
226227
) -> Result<DecodingSymbol<T::OrthogonalAxis>, RecoverySymbolError> {
227228
Self::check_index(target_pair_index.into(), config.n_shards())?;
228229
let target_sliver_index =
229-
target_pair_index.to_sliver_index::<T::OrthogonalAxis>(config.n_shards());
230+
target_pair_index.to_sliver_index(config.n_shards(), T::OrthogonalAxis::sliver_type());
230231

231232
Ok(DecodingSymbol::<T::OrthogonalAxis>::new(
232233
self.index.get(),
@@ -457,7 +458,7 @@ impl SliverPair {
457458
secondary: SliverData::new_empty(
458459
config.n_primary_source_symbols().get(),
459460
symbol_size,
460-
index.to_sliver_index::<Secondary>(config.n_shards()),
461+
index.to_sliver_index(config.n_shards(), SliverType::Secondary),
461462
),
462463
}
463464
}

crates/walrus-core/src/encoding/symbols.rs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -520,12 +520,9 @@ impl<U: MerkleAuth> GeneralRecoverySymbol<U> {
520520
let source_index = SliverIndex(self.symbol.index());
521521

522522
let source_sliver_type = self.symbol.source_type();
523-
let source_index = match source_sliver_type {
524-
SliverType::Primary => source_index.to_pair_index::<Primary>(n_shards),
525-
SliverType::Secondary => source_index.to_pair_index::<Secondary>(n_shards),
526-
};
523+
let source_pair_index = source_index.to_pair_index(n_shards, source_sliver_type);
527524

528-
metadata.get_sliver_hash(source_index, source_sliver_type)
525+
metadata.get_sliver_hash(source_pair_index, source_sliver_type)
529526
}
530527
}
531528

@@ -630,11 +627,12 @@ impl<T: EncodingAxis, U: MerkleAuth> RecoverySymbol<T, U> {
630627
if self.symbol.len() != expected_symbol_size {
631628
return Err(SymbolVerificationError::SymbolSizeMismatch);
632629
}
630+
let verification_axis = T::OrthogonalAxis::sliver_type();
633631
self.verify_proof(
634632
metadata
635633
.get_sliver_hash(
636-
SliverIndex(self.symbol.index).to_pair_index::<T::OrthogonalAxis>(n_shards),
637-
T::OrthogonalAxis::sliver_type(),
634+
SliverIndex(self.symbol.index).to_pair_index(n_shards, verification_axis),
635+
verification_axis,
638636
)
639637
.ok_or(SymbolVerificationError::InvalidMetadata)?,
640638
n_shards.get().into(),
@@ -797,7 +795,8 @@ mod tests {
797795
let (sliver_pairs, metadata) = config_enum.encode_with_metadata(blob)?;
798796

799797
let sliver = sliver_pairs[0].secondary.clone();
800-
let source_index = SliverPairIndex(0).to_sliver_index::<Secondary>(config.n_shards);
798+
let source_index =
799+
SliverPairIndex(0).to_sliver_index(config.n_shards, SliverType::Secondary);
801800

802801
for index in 0..n_shards {
803802
let target_index = SliverIndex(index);

crates/walrus-core/src/lib.rs

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,10 @@ use encoding::{
2626
EncodingAxis,
2727
EncodingConfig,
2828
EncodingConfigEnum,
29-
Primary,
3029
PrimaryRecoverySymbol,
3130
PrimarySliver,
3231
QuiltError,
3332
RecoverySymbolError,
34-
Secondary,
3533
SecondaryRecoverySymbol,
3634
SecondarySliver,
3735
SliverVerificationError,
@@ -484,11 +482,10 @@ impl SliverPairIndex {
484482
/// # Panics
485483
///
486484
/// Panics if the index is greater than or equal to `n_shards`.
487-
pub fn to_sliver_index<E: EncodingAxis>(self, n_shards: NonZeroU16) -> SliverIndex {
488-
if E::IS_PRIMARY {
489-
self.into()
490-
} else {
491-
(n_shards.get() - self.0 - 1).into()
485+
pub fn to_sliver_index(self, n_shards: NonZeroU16, sliver_type: SliverType) -> SliverIndex {
486+
match sliver_type {
487+
SliverType::Primary => self.into(),
488+
SliverType::Secondary => (n_shards.get() - self.0 - 1).into(),
492489
}
493490
}
494491
}
@@ -503,11 +500,10 @@ impl SliverIndex {
503500
/// # Panics
504501
///
505502
/// Panics if the index is greater than or equal to `n_shards`.
506-
pub fn to_pair_index<E: EncodingAxis>(self, n_shards: NonZeroU16) -> SliverPairIndex {
507-
if E::IS_PRIMARY {
508-
self.into()
509-
} else {
510-
(n_shards.get() - self.0 - 1).into()
503+
pub fn to_pair_index(self, n_shards: NonZeroU16, sliver_type: SliverType) -> SliverPairIndex {
504+
match sliver_type {
505+
SliverType::Primary => self.into(),
506+
SliverType::Secondary => (n_shards.get() - self.0 - 1).into(),
511507
}
512508
}
513509
}
@@ -598,6 +594,11 @@ impl Sliver {
598594
by_axis::flat_map!(self.as_ref(), |x| x.verify(encoding_config, metadata))
599595
}
600596

597+
/// Returns the [`SliverIndex`] of this sliver (primary or secondary).
598+
pub fn sliver_index(&self) -> SliverIndex {
599+
by_axis::flat_map!(self.as_ref(), |x| x.index)
600+
}
601+
601602
/// Returns the [`Sliver<T>`][Sliver] contained within the enum.
602603
pub fn to_raw<T>(self) -> Result<encoding::SliverData<T>, WrongSliverVariantError>
603604
where
@@ -933,10 +934,7 @@ impl SliverId {
933934

934935
/// Returns the [`SliverPairIndex`] of the identified sliver.
935936
pub fn pair_index(&self, n_shards: NonZeroU16) -> SliverPairIndex {
936-
match self {
937-
ByAxis::Primary(value) => value.to_pair_index::<Primary>(n_shards),
938-
ByAxis::Secondary(value) => value.to_pair_index::<Secondary>(n_shards),
939-
}
937+
self.into_inner().to_pair_index(n_shards, self.r#type())
940938
}
941939
}
942940

crates/walrus-sdk/src/client.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ impl<E: EncodingAxis> SliverSelector<E> {
159159
let indices_and_shards = sliver_indices
160160
.iter()
161161
.map(|sliver_index| {
162-
let pair_index = sliver_index.to_pair_index::<E>(n_shards);
162+
let pair_index = sliver_index.to_pair_index(n_shards, E::sliver_type());
163163
let shard_index = pair_index.to_shard_index(n_shards, blob_id);
164164
(*sliver_index, shard_index)
165165
})

crates/walrus-service/src/node.rs

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,7 @@ use walrus_core::{
7575
EncodingAxis,
7676
EncodingConfig,
7777
GeneralRecoverySymbol,
78-
Primary,
7978
RecoverySymbolError,
80-
Secondary,
8179
SliverData,
8280
source_symbols_for_n_shards,
8381
},
@@ -3339,11 +3337,11 @@ impl StorageNodeInner {
33393337

33403338
let primary_index = symbol_id.primary_sliver_index();
33413339
self.check_index(primary_index)?;
3342-
let primary_pair_index = primary_index.to_pair_index::<Primary>(n_shards);
3340+
let primary_pair_index = primary_index.to_pair_index(n_shards, SliverType::Primary);
33433341

33443342
let secondary_index = symbol_id.secondary_sliver_index();
33453343
self.check_index(secondary_index)?;
3346-
let secondary_pair_index = secondary_index.to_pair_index::<Secondary>(n_shards);
3344+
let secondary_pair_index = secondary_index.to_pair_index(n_shards, SliverType::Secondary);
33473345

33483346
let owned_shards = self.owned_shards_at_latest_epoch();
33493347

@@ -3480,11 +3478,12 @@ impl StorageNodeInner {
34803478
match *target_type {
34813479
SliverType::Primary => SymbolId::new(
34823480
*target,
3483-
pair_stored.to_sliver_index::<Secondary>(n_shards),
3481+
pair_stored.to_sliver_index(n_shards, SliverType::Secondary),
3482+
),
3483+
SliverType::Secondary => SymbolId::new(
3484+
pair_stored.to_sliver_index(n_shards, SliverType::Primary),
3485+
*target,
34843486
),
3485-
SliverType::Secondary => {
3486-
SymbolId::new(pair_stored.to_sliver_index::<Primary>(n_shards), *target)
3487-
}
34883487
}
34893488
};
34903489
Either::Right(self.owned_shards_at_latest_epoch().into_iter().map(map_fn))
@@ -3989,6 +3988,13 @@ impl ServiceState for StorageNodeInner {
39893988
intent: UploadIntent,
39903989
) -> Result<bool, StoreSliverError> {
39913990
self.check_index(sliver_pair_index)?;
3991+
let n_shards = self.n_shards();
3992+
let sliver_index = sliver.sliver_index();
3993+
ensure!(
3994+
sliver_pair_index == sliver_index.to_pair_index(n_shards, sliver.r#type()),
3995+
anyhow!("sliver index mismatch").into()
3996+
);
3997+
39923998
let (metadata_persisted, persisted) = self
39933999
.resolve_metadata_for_sliver(&blob_id, intent.is_pending())
39944000
.await?;
@@ -8905,7 +8911,7 @@ mod tests {
89058911
)?;
89068912

89078913
let target_pair_index =
8908-
target_sliver_index.to_pair_index::<Primary>(n_shards_nonzero);
8914+
target_sliver_index.to_pair_index(n_shards_nonzero, SliverType::Primary);
89098915
let expected_sliver = &blob_detail[0]
89108916
.pairs
89118917
.iter()
@@ -8932,7 +8938,7 @@ mod tests {
89328938
)?;
89338939

89348940
let target_pair_index =
8935-
target_sliver_index.to_pair_index::<Secondary>(n_shards_nonzero);
8941+
target_sliver_index.to_pair_index(n_shards_nonzero, SliverType::Secondary);
89368942
let expected_sliver = &blob_detail[0]
89378943
.pairs
89388944
.iter()

crates/walrus-service/src/node/committee/request_futures.rs

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -272,12 +272,7 @@ where
272272
shared: &'a NodeCommitteeServiceInner<T>,
273273
) -> Self {
274274
Self {
275-
target_index: match target_sliver_type {
276-
SliverType::Primary => sliver_id.to_sliver_index::<Primary>(metadata.n_shards()),
277-
SliverType::Secondary => {
278-
sliver_id.to_sliver_index::<Secondary>(metadata.n_shards())
279-
}
280-
},
275+
target_index: sliver_id.to_sliver_index(metadata.n_shards(), target_sliver_type),
281276
target_sliver_type,
282277
epoch_certified,
283278
backoff: ExponentialBackoffState::new_infinite(
@@ -668,10 +663,10 @@ impl<'a, T: NodeService> CollectRecoverySymbols<'a, T> {
668663
match self.target_sliver_type() {
669664
SliverType::Primary => SymbolId::new(
670665
self.target_index(),
671-
pair_at_shard.to_sliver_index::<Secondary>(n_shards),
666+
pair_at_shard.to_sliver_index(n_shards, SliverType::Secondary),
672667
),
673668
SliverType::Secondary => SymbolId::new(
674-
pair_at_shard.to_sliver_index::<Primary>(n_shards),
669+
pair_at_shard.to_sliver_index(n_shards, SliverType::Primary),
675670
self.target_index(),
676671
),
677672
}

crates/walrus-service/src/node/committee/test_committee_service.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ fn recovery_symbols_by_shard(
448448
let blob = walrus_test_utils::random_data(314);
449449
let n_shards = NonZero::new(n_shards).unwrap();
450450
let target_sliver_index = SliverIndex(0);
451-
let target_sliver_pair_index = target_sliver_index.to_pair_index::<Primary>(n_shards);
451+
let target_sliver_pair_index = target_sliver_index.to_pair_index(n_shards, SliverType::Primary);
452452

453453
let encoding_config = EncodingConfig::new(n_shards);
454454
let encoding_config_enum = encoding_config.get_for_type(DEFAULT_ENCODING);

crates/walrus-service/src/node/recovery_symbol_service.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ impl RecoverySymbolService {
200200
{
201201
let sliver_ref = sliver.as_ref();
202202
let target_sliver_index =
203-
target_pair_index.to_sliver_index::<T::OrthogonalAxis>(config.n_shards());
203+
target_pair_index.to_sliver_index(config.n_shards(), T::OrthogonalAxis::sliver_type());
204204
let is_source_target = usize::from(target_sliver_index.get()) < sliver_ref.symbols.len();
205205

206206
if is_source_target {

0 commit comments

Comments
 (0)